Browse Source

add frameworks for Vocabulary, Document

tags/v0.4.10
FengZiYjun 7 years ago
parent
commit
a09ddb918c
4 changed files with 89 additions and 3 deletions
  1. +1
    -0
      action/action.py
  2. +12
    -0
      action/tester.py
  3. +49
    -3
      action/trainer.py
  4. +27
    -0
      model/base_model.py

+ 1
- 0
action/action.py View File

@@ -1,3 +1,4 @@

class Action(object):
"""
base class for Trainer and Tester


+ 12
- 0
action/tester.py View File

@@ -14,6 +14,7 @@ class Tester(Action):
self.test_args = test_args
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
self.mean_loss = None
self.output = None

def test(self, network, data):
# transform into network input and label
@@ -22,6 +23,7 @@ class Tester(Action):
# split into batches by self.batch_size
iterations, test_batch_generator = self.batchify(X, Y)

batch_output = list()
loss_history = list()
# turn on the testing mode of the network
network.mode(test=True)
@@ -31,6 +33,7 @@ class Tester(Action):

# forward pass from tests input to predicted output
prediction = network.data_forward(batch_x)
batch_output.append(prediction)

# get the loss
loss = network.loss(batch_y, prediction)
@@ -39,7 +42,16 @@ class Tester(Action):
self.log(self.make_log(step, loss))

self.mean_loss = np.mean(np.array(loss_history))
self.output = self.make_output(batch_output)

@property
def loss(self):
return self.mean_loss

@property
def result(self):
return self.output

def make_output(self, batch_output):
# construct full prediction with batch outputs
raise NotImplementedError

+ 49
- 3
action/trainer.py View File

@@ -1,4 +1,5 @@
from action.action import Action
from action.tester import Tester


class Trainer(Action):
@@ -6,9 +7,54 @@ class Trainer(Action):
Trainer for common training logic of all models
"""

def __init__(self, arg):
def __init__(self, train_args):
"""
:param train_args: namedtuple
"""
super(Trainer, self).__init__()
self.arg = arg
self.train_args = train_args
self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
self.n_epochs = self.train_args.epochs
self.validate = True
self.save_when_better = True

def train(self, args):
def train(self, network, data, dev_data):
X, Y = network.prepare_input(data)

iterations, train_batch_generator = self.batchify(X, Y)
loss_history = list()
network.mode(test=False)

test_args = "..."
evaluator = Tester(test_args)
best_loss = 1e10

for epoch in range(self.n_epochs):

for step in range(iterations):
batch_x, batch_y = train_batch_generator.__next__()
prediction = network.data_forward(batch_x)
loss = network.loss(batch_y, prediction)
network.grad_backward()
loss_history.append(loss)
self.log(self.make_log(epoch, step, loss))

# evaluate over dev set
if self.validate:
evaluator.test(network, dev_data)
self.log(self.make_valid_log(epoch, evaluator.loss))
if evaluator.loss < best_loss:
best_loss = evaluator.loss
if self.save_when_better:
self.save_model(network)

# finish training

def make_log(self, *args):
raise NotImplementedError

def make_valid_log(self, *args):
raise NotImplementedError

def save_model(self, model):
raise NotImplementedError

+ 27
- 0
model/base_model.py View File

@@ -18,3 +18,30 @@ class BaseModel(object):

def loss(self, pred, truth):
raise NotImplementedError


class Vocabulary(object):
"""
A collection of lookup tables.
"""

def __init__(self):
self.word_set = None
self.word2idx = None
self.emb_matrix = None

def lookup(self, word):
if word in self.word_set:
return self.emb_matrix[self.word2idx[word]]
return LookupError("The key " + word + " does not exist.")


class Document(object):
"""
contains a sequence of tokens
each token is a character with linguistic attributes
"""

def __init__(self):
# wrap pandas.dataframe
self.dataframe = None

Loading…
Cancel
Save