@@ -0,0 +1,4 @@ | |||||
<?xml version="1.0" encoding="UTF-8"?> | |||||
<project version="4"> | |||||
<component name="PublishConfigData" persistUploadOnCheckin="false" /> | |||||
</project> |
@@ -0,0 +1,11 @@ | |||||
<?xml version="1.0" encoding="UTF-8"?> | |||||
<module type="PYTHON_MODULE" version="4"> | |||||
<component name="NewModuleRootManager"> | |||||
<content url="file://$MODULE_DIR$" /> | |||||
<orderEntry type="inheritedJdk" /> | |||||
<orderEntry type="sourceFolder" forTests="false" /> | |||||
</component> | |||||
<component name="TestRunnerService"> | |||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" /> | |||||
</component> | |||||
</module> |
@@ -0,0 +1,4 @@ | |||||
<?xml version="1.0" encoding="UTF-8"?> | |||||
<project version="4"> | |||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.5 (PCA_emb)" project-jdk-type="Python SDK" /> | |||||
</project> |
@@ -0,0 +1,8 @@ | |||||
<?xml version="1.0" encoding="UTF-8"?> | |||||
<project version="4"> | |||||
<component name="ProjectModuleManager"> | |||||
<modules> | |||||
<module fileurl="file://$PROJECT_DIR$/.idea/fastNLP.iml" filepath="$PROJECT_DIR$/.idea/fastNLP.iml" /> | |||||
</modules> | |||||
</component> | |||||
</project> |
@@ -0,0 +1,6 @@ | |||||
<?xml version="1.0" encoding="UTF-8"?> | |||||
<project version="4"> | |||||
<component name="VcsDirectoryMappings"> | |||||
<mapping directory="$PROJECT_DIR$" vcs="Git" /> | |||||
</component> | |||||
</project> |
@@ -1,2 +1,2 @@ | |||||
# FastNLP | # FastNLP | ||||
FastNLP | |||||
FastNLP |
@@ -1,4 +1,3 @@ | |||||
Some useful reference: | |||||
SpaCy "Doc" | SpaCy "Doc" | ||||
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80 | https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80 | ||||
@@ -8,10 +8,10 @@ class Action(object): | |||||
self.logger = None | self.logger = None | ||||
def load_config(self, args): | def load_config(self, args): | ||||
pass | |||||
raise NotImplementedError | |||||
def load_dataset(self, args): | def load_dataset(self, args): | ||||
pass | |||||
raise NotImplementedError | |||||
def log(self, args): | def log(self, args): | ||||
self.logger.log(args) | self.logger.log(args) | ||||
@@ -22,7 +22,7 @@ class Action(object): | |||||
def batchify(self, X, Y=None): | def batchify(self, X, Y=None): | ||||
# a generator | # a generator | ||||
pass | |||||
raise NotImplementedError | |||||
def make_log(self, *args): | def make_log(self, *args): | ||||
pass | |||||
raise NotImplementedError |
@@ -29,7 +29,7 @@ class Tester(Action): | |||||
for step in range(iterations): | for step in range(iterations): | ||||
batch_x, batch_y = test_batch_generator.__next__() | batch_x, batch_y = test_batch_generator.__next__() | ||||
# forward pass from test input to predicted output | |||||
# forward pass from test input to predicted output | |||||
prediction = network.data_forward(batch_x) | prediction = network.data_forward(batch_x) | ||||
# get the loss | # get the loss | ||||
@@ -11,4 +11,4 @@ class Trainer(Action): | |||||
self.arg = arg | self.arg = arg | ||||
def train(self, args): | def train(self, args): | ||||
pass | |||||
raise NotImplementedError |
@@ -10,5 +10,4 @@ class ConfigLoader(BaseLoader): | |||||
@staticmethod | @staticmethod | ||||
def parse(string): | def parse(string): | ||||
# To do | |||||
return string | |||||
raise NotImplementedError |
@@ -0,0 +1,20 @@ | |||||
class BaseModel(object):
    """Base model for all models.

    Declares the interface concrete models are expected to provide. Every
    method other than ``__init__`` raises ``NotImplementedError`` in this
    base class, so a subclass must override each method it relies on. The
    error message names the concrete class and the missing method to make
    an incomplete subclass easy to diagnose.
    """

    def __init__(self):
        pass

    def prepare_input(self, data):
        """Prepare ``data`` for the model. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "%s.prepare_input() is not implemented" % type(self).__name__)

    def mode(self, test=False):
        """Select the model's mode. Must be overridden.

        ``test`` defaults to ``False``; presumably it switches between
        training and testing behaviour -- confirm against subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "%s.mode() is not implemented" % type(self).__name__)

    def data_forward(self, x):
        """Run the forward computation on ``x``. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "%s.data_forward() is not implemented" % type(self).__name__)

    def grad_backward(self):
        """Run the backward (gradient) step. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "%s.grad_backward() is not implemented" % type(self).__name__)

    def loss(self, pred, truth):
        """Compute the loss between ``pred`` and ``truth``. Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "%s.loss() is not implemented" % type(self).__name__)
@@ -1,17 +1,12 @@ | |||||
import os | import os | ||||
import torch | |||||
import | |||||
import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
import torchvision.datasets as dsets | |||||
import torchvision.transforms as transforms | |||||
import dataset as dst | |||||
from model import CNN_text | |||||
.dataset as dst | |||||
from .model import CNN_text | |||||
from torch.autograd import Variable | from torch.autograd import Variable | ||||
from sklearn import cross_validation | |||||
from sklearn import datasets | |||||
# Hyper Parameters | # Hyper Parameters | ||||
batch_size = 50 | batch_size = 50 | ||||
learning_rate = 0.0001 | learning_rate = 0.0001 | ||||
@@ -51,8 +46,7 @@ if cuda: | |||||
criterion = nn.CrossEntropyLoss() | criterion = nn.CrossEntropyLoss() | ||||
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) | optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) | ||||
#train and test | |||||
# train and test | |||||
best_acc = None | best_acc = None | ||||
for epoch in range(num_epochs): | for epoch in range(num_epochs): | ||||
@@ -1,12 +1,12 @@ | |||||
import os | import os | ||||
from collections import namedtuple | |||||
import numpy as np | |||||
import torch | import torch | ||||
from torch.autograd import Variable | |||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | |||||
import numpy as np | |||||
from model import charLM | |||||
from torch.autograd import Variable | |||||
from utilities import * | from utilities import * | ||||
from collections import namedtuple | |||||
def to_var(x): | def to_var(x): | ||||
if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
@@ -76,18 +76,18 @@ if __name__ == "__main__": | |||||
if os.path.exists("cache/data_sets.pt") is False: | if os.path.exists("cache/data_sets.pt") is False: | ||||
test_text = read_data("./test.txt") | |||||
test_text = read_data("./tests.txt") | |||||
test_set = np.array(text2vec(test_text, char_dict, max_word_len)) | test_set = np.array(text2vec(test_text, char_dict, max_word_len)) | ||||
# Labels are next-word index in word_dict with the same length as inputs | # Labels are next-word index in word_dict with the same length as inputs | ||||
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) | test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) | ||||
category = {"test": test_set, "tlabel":test_label} | |||||
category = {"tests": test_set, "tlabel": test_label} | |||||
torch.save(category, "cache/data_sets.pt") | torch.save(category, "cache/data_sets.pt") | ||||
else: | else: | ||||
data_sets = torch.load("cache/data_sets.pt") | data_sets = torch.load("cache/data_sets.pt") | ||||
test_set = data_sets["test"] | |||||
test_set = data_sets["tests"] | |||||
test_label = data_sets["tlabel"] | test_label = data_sets["tlabel"] | ||||
train_set = data_sets["tdata"] | train_set = data_sets["tdata"] | ||||
train_label = data_sets["trlabel"] | train_label = data_sets["trlabel"] | ||||
@@ -13,8 +13,7 @@ from .utilities import * | |||||
def preprocess(): | def preprocess(): | ||||
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "test.txt") | |||||
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "tests.txt") | |||||
num_words = len(word_dict) | num_words = len(word_dict) | ||||
num_char = len(char_dict) | num_char = len(char_dict) | ||||
char_dict["BOW"] = num_char+1 | char_dict["BOW"] = num_char+1 | ||||
@@ -195,7 +194,7 @@ if __name__=="__main__": | |||||
if os.path.exists("cache/data_sets.pt") is False: | if os.path.exists("cache/data_sets.pt") is False: | ||||
train_text = read_data("./train.txt") | train_text = read_data("./train.txt") | ||||
valid_text = read_data("./valid.txt") | valid_text = read_data("./valid.txt") | ||||
test_text = read_data("./test.txt") | |||||
test_text = read_data("./tests.txt") | |||||
train_set = np.array(text2vec(train_text, char_dict, max_word_len)) | train_set = np.array(text2vec(train_text, char_dict, max_word_len)) | ||||
valid_set = np.array(text2vec(valid_text, char_dict, max_word_len)) | valid_set = np.array(text2vec(valid_text, char_dict, max_word_len)) | ||||
@@ -206,14 +205,14 @@ if __name__=="__main__": | |||||
valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]]) | valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]]) | ||||
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) | test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]]) | ||||
category = {"tdata":train_set, "vdata":valid_set, "test": test_set, | |||||
category = {"tdata": train_set, "vdata": valid_set, "tests": test_set, | |||||
"trlabel":train_label, "vlabel":valid_label, "tlabel":test_label} | "trlabel":train_label, "vlabel":valid_label, "tlabel":test_label} | ||||
torch.save(category, "cache/data_sets.pt") | torch.save(category, "cache/data_sets.pt") | ||||
else: | else: | ||||
data_sets = torch.load("cache/data_sets.pt") | data_sets = torch.load("cache/data_sets.pt") | ||||
train_set = data_sets["tdata"] | train_set = data_sets["tdata"] | ||||
valid_set = data_sets["vdata"] | valid_set = data_sets["vdata"] | ||||
test_set = data_sets["test"] | |||||
test_set = data_sets["tests"] | |||||
train_label = data_sets["trlabel"] | train_label = data_sets["trlabel"] | ||||
valid_label = data_sets["vlabel"] | valid_label = data_sets["vlabel"] | ||||
test_label = data_sets["tlabel"] | test_label = data_sets["tlabel"] | ||||
@@ -5,10 +5,10 @@ class BaseSaver(object): | |||||
self.save_path = save_path | self.save_path = save_path | ||||
def save_bytes(self): | def save_bytes(self): | ||||
pass | |||||
raise NotImplementedError | |||||
def save_str(self): | def save_str(self): | ||||
pass | |||||
raise NotImplementedError | |||||
def compress(self): | def compress(self): | ||||
pass | |||||
raise NotImplementedError |
@@ -8,4 +8,4 @@ class Logger(BaseSaver): | |||||
super(Logger, self).__init__(save_path) | super(Logger, self).__init__(save_path) | ||||
def log(self, string): | def log(self, string): | ||||
pass | |||||
raise NotImplementedError |