|
|
@@ -13,8 +13,7 @@ from .utilities import * |
|
|
|
|
|
|
|
|
|
|
|
def preprocess(): |
|
|
|
|
|
|
|
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "test.txt") |
|
|
|
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "tests.txt") |
|
|
|
num_words = len(word_dict) |
|
|
|
num_char = len(char_dict) |
|
|
|
char_dict["BOW"] = num_char+1 |
|
|
@@ -195,7 +194,7 @@ if __name__=="__main__": |
|
|
|
if os.path.exists("cache/data_sets.pt") is False: |
|
|
|
train_text = read_data("./train.txt") |
|
|
|
valid_text = read_data("./valid.txt") |
|
|
|
test_text = read_data("./test.txt") |
|
|
|
test_text = read_data("./tests.txt") |
|
|
|
|
|
|
|
train_set = np.array(text2vec(train_text, char_dict, max_word_len)) |
|
|
|
valid_set = np.array(text2vec(valid_text, char_dict, max_word_len)) |
|
|
@@ -206,14 +205,14 @@ if __name__=="__main__": |
|
|
|
valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]]) |
|
|
|
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) |
|
|
|
|
|
|
|
category = {"tdata":train_set, "vdata":valid_set, "test": test_set, |
|
|
|
category = {"tdata": train_set, "vdata": valid_set, "tests": test_set, |
|
|
|
"trlabel":train_label, "vlabel":valid_label, "tlabel":test_label} |
|
|
|
torch.save(category, "cache/data_sets.pt") |
|
|
|
else: |
|
|
|
data_sets = torch.load("cache/data_sets.pt") |
|
|
|
train_set = data_sets["tdata"] |
|
|
|
valid_set = data_sets["vdata"] |
|
|
|
test_set = data_sets["test"] |
|
|
|
test_set = data_sets["tests"] |
|
|
|
train_label = data_sets["trlabel"] |
|
|
|
valid_label = data_sets["vlabel"] |
|
|
|
test_label = data_sets["tlabel"] |
|
|
|