diff --git a/action/action.py b/action/action.py
index 8473a1a2..cea9fa65 100644
--- a/action/action.py
+++ b/action/action.py
@@ -1,3 +1,4 @@
+
 class Action(object):
     """
     base class for Trainer and Tester
diff --git a/action/tester.py b/action/tester.py
index 591d75ce..0bce1c5d 100644
--- a/action/tester.py
+++ b/action/tester.py
@@ -14,6 +14,7 @@ class Tester(Action):
         self.test_args = test_args
         self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
         self.mean_loss = None
+        self.output = None
 
     def test(self, network, data):
         # transform into network input and label
@@ -22,6 +23,7 @@
 
         # split into batches by self.batch_size
         iterations, test_batch_generator = self.batchify(X, Y)
+        batch_output = list()
         loss_history = list()
         # turn on the testing mode of the network
         network.mode(test=True)
@@ -31,6 +33,7 @@
 
             # forward pass from tests input to predicted output
             prediction = network.data_forward(batch_x)
+            batch_output.append(prediction)
 
             # get the loss
             loss = network.loss(batch_y, prediction)
@@ -39,7 +42,16 @@
             self.log(self.make_log(step, loss))
 
         self.mean_loss = np.mean(np.array(loss_history))
+        self.output = self.make_output(batch_output)
 
     @property
     def loss(self):
         return self.mean_loss
+
+    @property
+    def result(self):
+        return self.output
+
+    def make_output(self, batch_output):
+        # construct full prediction with batch outputs
+        raise NotImplementedError
diff --git a/action/trainer.py b/action/trainer.py
index f4429a98..e00aff08 100644
--- a/action/trainer.py
+++ b/action/trainer.py
@@ -1,4 +1,5 @@
 from action.action import Action
+from action.tester import Tester
 
 
 class Trainer(Action):
@@ -6,9 +7,54 @@
     Trainer for common training logic of all models
     """
 
-    def __init__(self, arg):
+    def __init__(self, train_args):
+        """
+        :param train_args: namedtuple
+        """
         super(Trainer, self).__init__()
-        self.arg = arg
+        self.train_args = train_args
+        self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
+        self.n_epochs = self.train_args.epochs
+        self.validate = True
+        self.save_when_better = True
 
-    def train(self, args):
+    def train(self, network, data, dev_data):
+        X, Y = network.prepare_input(data)
+
+        iterations, train_batch_generator = self.batchify(X, Y)
+        loss_history = list()
+        network.mode(test=False)
+
+        test_args = "..."
+        evaluator = Tester(test_args)
+        best_loss = 1e10
+
+        for epoch in range(self.n_epochs):
+
+            for step in range(iterations):
+                batch_x, batch_y = train_batch_generator.__next__()
+                prediction = network.data_forward(batch_x)
+                loss = network.loss(batch_y, prediction)
+                network.grad_backward()
+                loss_history.append(loss)
+                self.log(self.make_log(epoch, step, loss))
+
+            # evaluate over dev set
+            if self.validate:
+                evaluator.test(network, dev_data)
+                self.log(self.make_valid_log(epoch, evaluator.loss))
+                if evaluator.loss < best_loss:
+                    best_loss = evaluator.loss
+                    if self.save_when_better:
+                        self.save_model(network)
+
+        # finish training
+
+    def make_log(self, *args):
+        raise NotImplementedError
+
+    def make_valid_log(self, *args):
+        raise NotImplementedError
+
+    def save_model(self, model):
         raise NotImplementedError
diff --git a/model/base_model.py b/model/base_model.py
index 3298c3d6..f4fdf120 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -18,3 +18,30 @@
 
     def loss(self, pred, truth):
         raise NotImplementedError
+
+
+class Vocabulary(object):
+    """
+    A collection of lookup tables.
+    """
+
+    def __init__(self):
+        self.word_set = None
+        self.word2idx = None
+        self.emb_matrix = None
+
+    def lookup(self, word):
+        if word in self.word_set:
+            return self.emb_matrix[self.word2idx[word]]
+        return LookupError("The key " + word + " does not exist.")
+
+
+class Document(object):
+    """
+    contains a sequence of tokens
+    each token is a character with linguistic attributes
+    """
+
+    def __init__(self):
+        # wrap pandas.dataframe
+        self.dataframe = None