You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 3.0 kB

7 years ago
7 years ago
7 years ago
7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from collections import namedtuple
  2. from .action import Action
  3. from .tester import Tester
  4. class Trainer(Action):
  5. """
  6. Trainer for common training logic of all models
  7. """
  8. TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better",
  9. "log_per_step", "log_validation", "batch_size"])
  10. def __init__(self, train_args):
  11. """
  12. :param train_args: namedtuple
  13. """
  14. super(Trainer, self).__init__()
  15. self.n_epochs = train_args.epochs
  16. self.validate = train_args.validate
  17. self.save_when_better = train_args.save_when_better
  18. self.log_per_step = train_args.log_per_step
  19. self.log_validation = train_args.log_validation
  20. self.batch_size = train_args.batch_size
  21. def train(self, network, train_data, dev_data):
  22. """
  23. :param network: the model controller
  24. :param train_data: raw data for training
  25. :param dev_data: raw data for validation
  26. :return:
  27. """
  28. train_x, train_y = network.prepare_input(train_data)
  29. iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y)
  30. test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
  31. save_dev_input=True, save_loss=True, batch_size=self.batch_size)
  32. evaluator = Tester(test_args)
  33. best_loss = 1e10
  34. loss_history = list()
  35. for epoch in range(self.n_epochs):
  36. network.mode(test=False) # turn on the train mode
  37. network.define_optimizer()
  38. for step in range(iterations):
  39. batch_x, batch_y = train_batch_generator.__next__()
  40. prediction = network.data_forward(batch_x)
  41. loss = network.get_loss(prediction, batch_y)
  42. network.grad_backward()
  43. if step % self.log_per_step == 0:
  44. print("step ", step)
  45. loss_history.append(loss)
  46. self.log(self.make_log(epoch, step, loss))
  47. #################### evaluate over dev set ###################
  48. if self.validate:
  49. # give all controls to tester
  50. evaluator.test(network, dev_data)
  51. if self.log_validation:
  52. self.log(self.make_valid_log(epoch, evaluator.loss))
  53. if evaluator.loss < best_loss:
  54. best_loss = evaluator.loss
  55. if self.save_when_better:
  56. self.save_model(network)
  57. # finish training
  58. def make_log(self, *args):
  59. return "make a log"
  60. def make_valid_log(self, *args):
  61. return "make a valid log"
  62. def save_model(self, model):
  63. model.save()
  64. def load_data(self, data_name):
  65. print("load data")
  66. def load_config(self, args):
  67. raise NotImplementedError
  68. def load_dataset(self, args):
  69. raise NotImplementedError

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等