Browse Source

add toy model to check data flow all right

tags/v0.4.10
FengZiYjun 7 years ago
parent
commit
cd7b6b1e8d
5 changed files with 84 additions and 18 deletions
  1. +19
    -8
      action/action.py
  2. +2
    -2
      action/tester.py
  3. +8
    -8
      action/trainer.py
  4. +34
    -0
      model/base_model.py
  5. +21
    -0
      tests/test_trainer.py

+ 19
- 8
action/action.py View File

@@ -15,15 +15,26 @@ class Action(object):
raise NotImplementedError

def log(self, args):
self.logger.log(args)

"""
Basic operations shared between Trainer and Tester.
"""
print("call logger.log")

def batchify(self, X, Y=None):
# a generator
raise NotImplementedError
"""
:param X:
:param Y:
:return iteration:int, the number of step in each epoch
generator:generator, to generate batch inputs
"""
data = X
if Y is not None:
data = [X, Y]
return 2, self._batch_generate(data)

def _batch_generate(self, data):
step = 10
for i in range(2):
start = i * step
end = (i + 1) * step
yield data[0][start:end], data[1][start:end]

def make_log(self, *args):
raise NotImplementedError
return "log"

+ 2
- 2
action/tester.py View File

@@ -12,7 +12,7 @@ class Tester(Action):
"""
super(Tester, self).__init__()
self.test_args = test_args
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
# self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
self.mean_loss = None
self.output = None

@@ -54,4 +54,4 @@ class Tester(Action):

def make_output(self, batch_output):
# construct full prediction with batch outputs
raise NotImplementedError
return np.concatenate((batch_output[0], batch_output[1]), axis=0)

+ 8
- 8
action/trainer.py View File

@@ -1,5 +1,5 @@
from action.action import Action
from action.tester import Tester
from .action import Action
from .tester import Tester


class Trainer(Action):
@@ -13,10 +13,10 @@ class Trainer(Action):
"""
super(Trainer, self).__init__()
self.train_args = train_args
self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
# self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
self.n_epochs = self.train_args.epochs
self.validate = True
self.save_when_better = True
self.validate = self.train_args.validate
self.save_when_better = self.train_args.save_when_better

def train(self, network, data, dev_data):
X, Y = network.prepare_input(data)
@@ -51,10 +51,10 @@ class Trainer(Action):
# finish training

def make_log(self, *args):
raise NotImplementedError
print("logged")

def make_valid_log(self, *args):
raise NotImplementedError
print("logged")

def save_model(self, model):
raise NotImplementedError
print("model saved")

+ 34
- 0
model/base_model.py View File

@@ -1,3 +1,6 @@
import numpy as np


class BaseModel(object):
"""base model for all models"""

@@ -5,6 +8,10 @@ class BaseModel(object):
pass

def prepare_input(self, data):
"""
:param data: str, raw input vector(?)
:return (X, Y): tuple, input features and labels
"""
raise NotImplementedError

def mode(self, test=False):
@@ -20,6 +27,33 @@ class BaseModel(object):
raise NotImplementedError


class ToyModel(BaseModel):
"""This is for code testing."""

def __init__(self):
super(ToyModel, self).__init__()
self.test_mode = False
self.weight = np.random.rand(5, 1)
self.bias = np.random.rand()
self._loss = 0

def prepare_input(self, data):
return data[:, :-1], data[:, -1]

def mode(self, test=False):
self.test_mode = test

def data_forward(self, x):
return np.matmul(x, self.weight) + self.bias

def grad_backward(self):
print("loss gradient backward")

def loss(self, pred, truth):
self._loss = np.mean(np.square(pred - truth))
return self._loss


class Vocabulary(object):
"""
A collection of lookup tables.


+ 21
- 0
tests/test_trainer.py View File

@@ -0,0 +1,21 @@
from collections import namedtuple

import numpy as np

from action.trainer import Trainer
from model.base_model import ToyModel


def test_trainer():
Config = namedtuple("config", ["epochs", "validate", "save_when_better"])
train_config = Config(epochs=5, validate=True, save_when_better=True)
trainer = Trainer(train_config)

net = ToyModel()
data = np.random.rand(20, 6)
dev_data = np.random.rand(20, 6)
trainer.train(net, data, dev_data)


if __name__ == "__main__":
test_trainer()

Loading…
Cancel
Save