You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

action.py 1.4 kB

7 years ago
7 years ago
7 years ago
7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from saver.logger import Logger
  2. class Action(object):
  3. """
  4. base class for Trainer and Tester
  5. """
  6. def __init__(self):
  7. super(Action, self).__init__()
  8. self.logger = Logger("logger_output.txt")
  9. def load_config(self, args):
  10. raise NotImplementedError
  11. def load_dataset(self, args):
  12. raise NotImplementedError
  13. def log(self, string):
  14. self.logger.log(string)
  15. def batchify(self, batch_size, X, Y=None):
  16. """
  17. :param batch_size: int
  18. :param X: feature matrix of size [n_sample, m_feature]
  19. :param Y: label vector of size [n_sample, 1] (optional)
  20. :return iteration:int, the number of step in each epoch
  21. generator:generator, to generate batch inputs
  22. """
  23. n_samples = X.size()[0]
  24. num_iter = n_samples // batch_size
  25. if Y is None:
  26. generator = self._batch_generate(batch_size, num_iter, X)
  27. else:
  28. generator = self._batch_generate(batch_size, num_iter, X, Y)
  29. return num_iter, generator
  30. @staticmethod
  31. def _batch_generate(batch_size, num_iter, *data):
  32. for step in range(num_iter):
  33. start = batch_size * step
  34. end = batch_size * (step + 1)
  35. yield tuple([x[start:end] for x in data])
  36. def make_log(self, *args):
  37. return "log"

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。