@@ -1,3 +1,4 @@ | |||||
class Action(object): | class Action(object): | ||||
""" | """ | ||||
base class for Trainer and Tester | base class for Trainer and Tester | ||||
@@ -14,6 +14,7 @@ class Tester(Action): | |||||
self.test_args = test_args | self.test_args = test_args | ||||
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | ||||
self.mean_loss = None | self.mean_loss = None | ||||
self.output = None | |||||
def test(self, network, data): | def test(self, network, data): | ||||
# transform into network input and label | # transform into network input and label | ||||
@@ -22,6 +23,7 @@ class Tester(Action): | |||||
# split into batches by self.batch_size | # split into batches by self.batch_size | ||||
iterations, test_batch_generator = self.batchify(X, Y) | iterations, test_batch_generator = self.batchify(X, Y) | ||||
batch_output = list() | |||||
loss_history = list() | loss_history = list() | ||||
# turn on the testing mode of the network | # turn on the testing mode of the network | ||||
network.mode(test=True) | network.mode(test=True) | ||||
@@ -31,6 +33,7 @@ class Tester(Action): | |||||
# forward pass from tests input to predicted output | # forward pass from tests input to predicted output | ||||
prediction = network.data_forward(batch_x) | prediction = network.data_forward(batch_x) | ||||
batch_output.append(prediction) | |||||
# get the loss | # get the loss | ||||
loss = network.loss(batch_y, prediction) | loss = network.loss(batch_y, prediction) | ||||
@@ -39,7 +42,16 @@ class Tester(Action): | |||||
self.log(self.make_log(step, loss)) | self.log(self.make_log(step, loss)) | ||||
self.mean_loss = np.mean(np.array(loss_history)) | self.mean_loss = np.mean(np.array(loss_history)) | ||||
self.output = self.make_output(batch_output) | |||||
    @property
    def loss(self):
        """Mean loss over the last test() run; ``None`` until test() has been called."""
        return self.mean_loss
    @property
    def result(self):
        """Full prediction assembled by the last test() run (via make_output);
        ``None`` until test() has been called."""
        return self.output
    def make_output(self, batch_output):
        """Stitch the per-batch predictions collected by test() into one full prediction.

        :param batch_output: list of per-batch network outputs, in batch order
        Subclasses must implement this.
        """
        # construct full prediction with batch outputs
        raise NotImplementedError
@@ -1,4 +1,5 @@ | |||||
from action.action import Action | from action.action import Action | ||||
from action.tester import Tester | |||||
class Trainer(Action): | class Trainer(Action): | ||||
@@ -6,9 +7,54 @@ class Trainer(Action): | |||||
Trainer for common training logic of all models | Trainer for common training logic of all models | ||||
""" | """ | ||||
def __init__(self, arg): | |||||
def __init__(self, train_args): | |||||
""" | |||||
:param train_args: namedtuple | |||||
""" | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
self.arg = arg | |||||
self.train_args = train_args | |||||
self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()} | |||||
self.n_epochs = self.train_args.epochs | |||||
self.validate = True | |||||
self.save_when_better = True | |||||
def train(self, args): | |||||
def train(self, network, data, dev_data): | |||||
X, Y = network.prepare_input(data) | |||||
iterations, train_batch_generator = self.batchify(X, Y) | |||||
loss_history = list() | |||||
network.mode(test=False) | |||||
test_args = "..." | |||||
evaluator = Tester(test_args) | |||||
best_loss = 1e10 | |||||
for epoch in range(self.n_epochs): | |||||
for step in range(iterations): | |||||
batch_x, batch_y = train_batch_generator.__next__() | |||||
prediction = network.data_forward(batch_x) | |||||
loss = network.loss(batch_y, prediction) | |||||
network.grad_backward() | |||||
loss_history.append(loss) | |||||
self.log(self.make_log(epoch, step, loss)) | |||||
# evaluate over dev set | |||||
if self.validate: | |||||
evaluator.test(network, dev_data) | |||||
self.log(self.make_valid_log(epoch, evaluator.loss)) | |||||
if evaluator.loss < best_loss: | |||||
best_loss = evaluator.loss | |||||
if self.save_when_better: | |||||
self.save_model(network) | |||||
# finish training | |||||
    def make_log(self, *args):
        """Format a per-step training log entry; implemented by subclasses."""
        raise NotImplementedError

    def make_valid_log(self, *args):
        """Format a per-epoch validation log entry; implemented by subclasses."""
        raise NotImplementedError

    def save_model(self, model):
        """Persist ``model`` to storage; implemented by subclasses."""
        raise NotImplementedError
@@ -18,3 +18,30 @@ class BaseModel(object): | |||||
    def loss(self, pred, truth):
        """Compute the loss between a prediction and the ground truth;
        implemented by subclasses.

        :param pred: model prediction
        :param truth: gold labels
        """
        raise NotImplementedError
class Vocabulary(object):
    """
    A collection of lookup tables mapping words to embedding vectors.

    Attributes (populated externally after construction):
      word_set   -- set of known words
      word2idx   -- dict mapping word -> row index into emb_matrix
      emb_matrix -- embedding matrix indexed by word2idx values
    """

    def __init__(self):
        self.word_set = None
        self.word2idx = None
        self.emb_matrix = None

    def lookup(self, word):
        """Return the embedding vector for ``word``.

        :raises LookupError: if ``word`` is not in the vocabulary
        """
        if word in self.word_set:
            return self.emb_matrix[self.word2idx[word]]
        # BUG FIX: the original *returned* the LookupError instance instead of
        # raising it, handing callers an exception object as a "result".
        raise LookupError("The key " + word + " does not exist.")
class Document(object):
    """A sequence of tokens, where each token is a character carrying
    linguistic attributes.

    The token table is backed by a pandas DataFrame that is attached
    externally; it stays ``None`` until assigned.
    """

    def __init__(self):
        # wraps a pandas.DataFrame; set by the owner after construction
        self.dataframe = None