From cd7b6b1e8d5c132dcd29f0840eff2a36d8f04df3 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Tue, 22 May 2018 16:28:33 +0800
Subject: [PATCH] add toy model to check data flow all right

---
 action/action.py      | 27 +++++++++++++++++++--------
 action/tester.py      |  4 ++--
 action/trainer.py     | 16 ++++++++--------
 model/base_model.py   | 34 ++++++++++++++++++++++++++++++++++
 tests/test_trainer.py | 21 +++++++++++++++++++++
 5 files changed, 84 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_trainer.py

diff --git a/action/action.py b/action/action.py
index cea9fa65..557c7ef2 100644
--- a/action/action.py
+++ b/action/action.py
@@ -15,15 +15,26 @@ class Action(object):
         raise NotImplementedError
 
     def log(self, args):
-        self.logger.log(args)
-
-    """
-    Basic operations shared between Trainer and Tester.
-    """
+        print("call logger.log")
 
     def batchify(self, X, Y=None):
-        # a generator
-        raise NotImplementedError
+        """
+        :param X:
+        :param Y:
+        :return iteration:int, the number of step in each epoch
+                generator:generator, to generate batch inputs
+        """
+        data = X
+        if Y is not None:
+            data = [X, Y]
+        return 2, self._batch_generate(data)
+
+    def _batch_generate(self, data):
+        step = 10
+        for i in range(2):
+            start = i * step
+            end = (i + 1) * step
+            yield data[0][start:end], data[1][start:end]
 
     def make_log(self, *args):
-        raise NotImplementedError
+        return "log"
diff --git a/action/tester.py b/action/tester.py
index 0bce1c5d..96c8a0ae 100644
--- a/action/tester.py
+++ b/action/tester.py
@@ -12,7 +12,7 @@ class Tester(Action):
         """
         super(Tester, self).__init__()
         self.test_args = test_args
-        self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
+        # self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
 
         self.mean_loss = None
         self.output = None
@@ -54,4 +54,4 @@ class Tester(Action):
 
     def make_output(self, batch_output):
         # construct full prediction with batch outputs
-        raise NotImplementedError
+        return np.concatenate((batch_output[0], batch_output[1]), axis=0)
diff --git a/action/trainer.py b/action/trainer.py
index e00aff08..cd82e544 100644
--- a/action/trainer.py
+++ b/action/trainer.py
@@ -1,5 +1,5 @@
-from action.action import Action
-from action.tester import Tester
+from .action import Action
+from .tester import Tester
 
 
 class Trainer(Action):
@@ -13,10 +13,10 @@ class Trainer(Action):
         """
         super(Trainer, self).__init__()
         self.train_args = train_args
-        self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
+        # self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
         self.n_epochs = self.train_args.epochs
-        self.validate = True
-        self.save_when_better = True
+        self.validate = self.train_args.validate
+        self.save_when_better = self.train_args.save_when_better
 
     def train(self, network, data, dev_data):
         X, Y = network.prepare_input(data)
@@ -51,10 +51,10 @@ class Trainer(Action):
         # finish training
 
     def make_log(self, *args):
-        raise NotImplementedError
+        print("logged")
 
     def make_valid_log(self, *args):
-        raise NotImplementedError
+        print("logged")
 
     def save_model(self, model):
-        raise NotImplementedError
+        print("model saved")
diff --git a/model/base_model.py b/model/base_model.py
index f4fdf120..28d1fe1e 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -1,3 +1,6 @@
+import numpy as np
+
+
 class BaseModel(object):
     """base model for all models"""
 
@@ -5,6 +8,10 @@ class BaseModel(object):
         pass
 
     def prepare_input(self, data):
+        """
+        :param data: str, raw input vector(?)
+        :return (X, Y): tuple, input features and labels
+        """
         raise NotImplementedError
 
     def mode(self, test=False):
@@ -20,6 +27,33 @@
         raise NotImplementedError
 
 
+class ToyModel(BaseModel):
+    """This is for code testing."""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.test_mode = False
+        self.weight = np.random.rand(5, 1)
+        self.bias = np.random.rand()
+        self._loss = 0
+
+    def prepare_input(self, data):
+        return data[:, :-1], data[:, -1]
+
+    def mode(self, test=False):
+        self.test_mode = test
+
+    def data_forward(self, x):
+        return np.matmul(x, self.weight) + self.bias
+
+    def grad_backward(self):
+        print("loss gradient backward")
+
+    def loss(self, pred, truth):
+        self._loss = np.mean(np.square(pred - truth))
+        return self._loss
+
+
 class Vocabulary(object):
     """
     A collection of lookup tables.
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
new file mode 100644
index 00000000..0b0d4553
--- /dev/null
+++ b/tests/test_trainer.py
@@ -0,0 +1,21 @@
+from collections import namedtuple
+
+import numpy as np
+
+from action.trainer import Trainer
+from model.base_model import ToyModel
+
+
+def test_trainer():
+    Config = namedtuple("config", ["epochs", "validate", "save_when_better"])
+    train_config = Config(epochs=5, validate=True, save_when_better=True)
+    trainer = Trainer(train_config)
+
+    net = ToyModel()
+    data = np.random.rand(20, 6)
+    dev_data = np.random.rand(20, 6)
+    trainer.train(net, data, dev_data)
+
+
+if __name__ == "__main__":
+    test_trainer()
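
Note: the toy data flow that this patch wires together (and that tests/test_trainer.py exercises) can be sketched standalone with plain numpy. The snippet below is an illustrative sketch only, not code from the repository: it mirrors ToyModel.prepare_input, Action.batchify, ToyModel.data_forward, and ToyModel.loss on the same (20, 6) random input the test uses, with the prediction squeezed so the MSE is computed element-wise rather than via broadcasting.

import numpy as np

# Toy data: 20 samples with 5 feature columns and 1 label column,
# matching the np.random.rand(20, 6) input in tests/test_trainer.py.
data = np.random.rand(20, 6)

# ToyModel.prepare_input: split features and labels.
X, Y = data[:, :-1], data[:, -1]

# ToyModel parameters: a random linear layer, as in ToyModel.__init__.
weight, bias = np.random.rand(5, 1), np.random.rand()

# Action.batchify / _batch_generate: two steps of 10 samples each.
for step in range(2):
    batch_x = X[step * 10:(step + 1) * 10]
    batch_y = Y[step * 10:(step + 1) * 10]

    # ToyModel.data_forward: linear projection, shape (10, 1).
    pred = np.matmul(batch_x, weight) + bias

    # ToyModel.loss: mean squared error. pred is squeezed to (10,) so the
    # subtraction with batch_y does not broadcast to a (10, 10) matrix.
    loss = np.mean(np.square(pred.squeeze(-1) - batch_y))
    print("step", step, "loss", loss)

With the project root on the import path, the new test itself can be run directly (python tests/test_trainer.py, via its __main__ guard) or picked up by a test runner such as pytest.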