@@ -1,3 +1,4 @@ | |||
class Action(object): | |||
""" | |||
base class for Trainer and Tester | |||
@@ -14,6 +14,7 @@ class Tester(Action): | |||
self.test_args = test_args | |||
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
self.mean_loss = None | |||
self.output = None | |||
def test(self, network, data): | |||
# transform into network input and label | |||
@@ -22,6 +23,7 @@ class Tester(Action): | |||
# split into batches by self.batch_size | |||
iterations, test_batch_generator = self.batchify(X, Y) | |||
batch_output = list() | |||
loss_history = list() | |||
# turn on the testing mode of the network | |||
network.mode(test=True) | |||
@@ -31,6 +33,7 @@ class Tester(Action): | |||
# forward pass from tests input to predicted output | |||
prediction = network.data_forward(batch_x) | |||
batch_output.append(prediction) | |||
# get the loss | |||
loss = network.loss(batch_y, prediction) | |||
@@ -39,7 +42,16 @@ class Tester(Action): | |||
self.log(self.make_log(step, loss)) | |||
self.mean_loss = np.mean(np.array(loss_history)) | |||
self.output = self.make_output(batch_output) | |||
@property
def loss(self):
    """Mean test loss produced by the most recent call to ``test``.

    ``None`` until ``test`` has been run at least once.
    """
    return self.mean_loss
@property
def result(self):
    """Full prediction assembled by the most recent call to ``test``.

    ``None`` until ``test`` has been run at least once.
    """
    return self.output
def make_output(self, batch_output):
    """Construct the full prediction from the list of per-batch outputs.

    :param batch_output: list of per-batch predictions collected by ``test``
    :raises NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError
@@ -1,4 +1,5 @@ | |||
from action.action import Action | |||
from action.tester import Tester | |||
class Trainer(Action):
    """
    Trainer for common training logic of all models.

    Subclasses must implement ``make_log``, ``make_valid_log`` and
    ``save_model``; batching and logging come from the ``Action`` base.
    """

    def __init__(self, train_args):
        """
        :param train_args: namedtuple (or simple namespace) of training
            settings; must expose an ``epochs`` field.
        """
        super(Trainer, self).__init__()
        self.train_args = train_args
        # Bug fix: dict.iteritems() is Python 2 only, while the batch
        # generators below are consumed with __next__() (Python 3) -- the
        # original comprehension would raise AttributeError at runtime.
        # Also, namedtuples don't carry a populated __dict__ on modern
        # Python 3; use _asdict() when available.
        if hasattr(train_args, "_asdict"):
            self.args_dict = dict(train_args._asdict())
        else:
            self.args_dict = dict(vars(train_args))
        self.n_epochs = self.train_args.epochs
        self.validate = True
        self.save_when_better = True

    def train(self, network, data, dev_data):
        """Run the training loop over ``data``, validating on ``dev_data``.

        :param network: model object providing ``prepare_input``,
            ``data_forward``, ``loss``, ``grad_backward`` and ``mode``
        :param data: raw training data, converted via network.prepare_input
        :param dev_data: held-out data evaluated after each epoch
        """
        # transform into network input and label
        X, Y = network.prepare_input(data)
        # split into batches by self.batch_size
        iterations, train_batch_generator = self.batchify(X, Y)
        loss_history = list()
        # turn on the training mode of the network
        network.mode(test=False)

        # TODO(review): placeholder -- real Tester arguments are needed here
        test_args = "..."
        evaluator = Tester(test_args)

        best_loss = 1e10
        for epoch in range(self.n_epochs):
            for step in range(iterations):
                batch_x, batch_y = train_batch_generator.__next__()
                # forward pass from train input to predicted output
                prediction = network.data_forward(batch_x)
                loss = network.loss(batch_y, prediction)
                network.grad_backward()
                loss_history.append(loss)
                self.log(self.make_log(epoch, step, loss))

            # evaluate over dev set once per epoch
            # NOTE(review): placement after the step loop assumed from the
            # flattened diff -- confirm against the original file.
            if self.validate:
                evaluator.test(network, dev_data)
                self.log(self.make_valid_log(epoch, evaluator.loss))
                if evaluator.loss < best_loss:
                    best_loss = evaluator.loss
                    if self.save_when_better:
                        self.save_model(network)
        # finish training

    def make_log(self, *args):
        """Build a per-step training log entry. Subclasses must override."""
        raise NotImplementedError

    def make_valid_log(self, *args):
        """Build a per-epoch validation log entry. Subclasses must override."""
        raise NotImplementedError

    def save_model(self, model):
        """Persist the given model. Subclasses must override."""
        raise NotImplementedError
@@ -18,3 +18,30 @@ class BaseModel(object): | |||
def loss(self, pred, truth):
    """Compute the loss between a prediction and the ground truth.

    :param pred: model prediction
    :param truth: ground-truth target
    :raises NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError
class Vocabulary(object):
    """
    A collection of lookup tables mapping words to embedding vectors.

    Attributes (populated externally after construction):
        word_set   -- set of known words
        word2idx   -- dict mapping a word to its row index in emb_matrix
        emb_matrix -- indexable matrix of embedding vectors
    """

    def __init__(self):
        self.word_set = None
        self.word2idx = None
        self.emb_matrix = None

    def lookup(self, word):
        """Return the embedding vector for ``word``.

        :param word: word to look up
        :raises LookupError: if the word is not in the vocabulary
        """
        if word in self.word_set:
            return self.emb_matrix[self.word2idx[word]]
        # Bug fix: the original RETURNED the exception object instead of
        # raising it, so callers silently received a LookupError instance.
        raise LookupError("The key " + word + " does not exist.")
class Document(object):
    """
    Contains a sequence of tokens; each token is a character carrying
    linguistic attributes.
    """

    def __init__(self):
        # Wraps a pandas.DataFrame; populated externally after construction.
        self.dataframe = None