@@ -15,15 +15,26 @@ class Action(object): | |||
raise NotImplementedError | |||
def log(self, args): | |||
self.logger.log(args) | |||
""" | |||
Basic operations shared between Trainer and Tester. | |||
""" | |||
print("call logger.log") | |||
def batchify(self, X, Y=None): | |||
# a generator | |||
raise NotImplementedError | |||
""" | |||
:param X: | |||
:param Y: | |||
:return iteration:int, the number of step in each epoch | |||
generator:generator, to generate batch inputs | |||
""" | |||
data = X | |||
if Y is not None: | |||
data = [X, Y] | |||
return 2, self._batch_generate(data) | |||
def _batch_generate(self, data): | |||
step = 10 | |||
for i in range(2): | |||
start = i * step | |||
end = (i + 1) * step | |||
yield data[0][start:end], data[1][start:end] | |||
def make_log(self, *args): | |||
raise NotImplementedError | |||
return "log" |
@@ -12,7 +12,7 @@ class Tester(Action): | |||
""" | |||
super(Tester, self).__init__() | |||
self.test_args = test_args | |||
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
# self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
self.mean_loss = None | |||
self.output = None | |||
@@ -54,4 +54,4 @@ class Tester(Action): | |||
def make_output(self, batch_output): | |||
# construct full prediction with batch outputs | |||
raise NotImplementedError | |||
return np.concatenate((batch_output[0], batch_output[1]), axis=0) |
@@ -1,5 +1,5 @@ | |||
from action.action import Action | |||
from action.tester import Tester | |||
from .action import Action | |||
from .tester import Tester | |||
class Trainer(Action): | |||
@@ -13,10 +13,10 @@ class Trainer(Action): | |||
""" | |||
super(Trainer, self).__init__() | |||
self.train_args = train_args | |||
self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()} | |||
# self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()} | |||
self.n_epochs = self.train_args.epochs | |||
self.validate = True | |||
self.save_when_better = True | |||
self.validate = self.train_args.validate | |||
self.save_when_better = self.train_args.save_when_better | |||
def train(self, network, data, dev_data): | |||
X, Y = network.prepare_input(data) | |||
@@ -51,10 +51,10 @@ class Trainer(Action): | |||
# finish training | |||
def make_log(self, *args): | |||
raise NotImplementedError | |||
print("logged") | |||
def make_valid_log(self, *args): | |||
raise NotImplementedError | |||
print("logged") | |||
def save_model(self, model): | |||
raise NotImplementedError | |||
print("model saved") |
@@ -1,3 +1,6 @@ | |||
import numpy as np | |||
class BaseModel(object): | |||
"""base model for all models""" | |||
@@ -5,6 +8,10 @@ class BaseModel(object): | |||
pass | |||
def prepare_input(self, data): | |||
""" | |||
:param data: str, raw input vector(?) | |||
:return (X, Y): tuple, input features and labels | |||
""" | |||
raise NotImplementedError | |||
def mode(self, test=False): | |||
@@ -20,6 +27,33 @@ class BaseModel(object): | |||
raise NotImplementedError | |||
class ToyModel(BaseModel): | |||
"""This is for code testing.""" | |||
def __init__(self): | |||
super(ToyModel, self).__init__() | |||
self.test_mode = False | |||
self.weight = np.random.rand(5, 1) | |||
self.bias = np.random.rand() | |||
self._loss = 0 | |||
def prepare_input(self, data): | |||
return data[:, :-1], data[:, -1] | |||
def mode(self, test=False): | |||
self.test_mode = test | |||
def data_forward(self, x): | |||
return np.matmul(x, self.weight) + self.bias | |||
def grad_backward(self): | |||
print("loss gradient backward") | |||
def loss(self, pred, truth): | |||
self._loss = np.mean(np.square(pred - truth)) | |||
return self._loss | |||
class Vocabulary(object): | |||
""" | |||
A collection of lookup tables. | |||
@@ -0,0 +1,21 @@ | |||
from collections import namedtuple | |||
import numpy as np | |||
from action.trainer import Trainer | |||
from model.base_model import ToyModel | |||
def test_trainer(): | |||
Config = namedtuple("config", ["epochs", "validate", "save_when_better"]) | |||
train_config = Config(epochs=5, validate=True, save_when_better=True) | |||
trainer = Trainer(train_config) | |||
net = ToyModel() | |||
data = np.random.rand(20, 6) | |||
dev_data = np.random.rand(20, 6) | |||
trainer.train(net, data, dev_data) | |||
if __name__ == "__main__": | |||
test_trainer() |