You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tester.py 2.7 kB

7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from collections import namedtuple
  2. import numpy as np
  3. from action.action import Action
  4. class Tester(Action):
  5. """docstring for Tester"""
  6. TestConfig = namedtuple("config", ["validate_in_training", "save_dev_input", "save_output",
  7. "save_loss", "batch_size"])
  8. def __init__(self, test_args):
  9. """
  10. :param test_args: named tuple
  11. """
  12. super(Tester, self).__init__()
  13. self.validate_in_training = test_args.validate_in_training
  14. self.save_dev_input = test_args.save_dev_input
  15. self.valid_x = None
  16. self.valid_y = None
  17. self.save_output = test_args.save_output
  18. self.output = None
  19. self.save_loss = test_args.save_loss
  20. self.mean_loss = None
  21. self.batch_size = test_args.batch_size
  22. def test(self, network, data):
  23. print("testing")
  24. network.mode(test=True) # turn on the testing mode
  25. if self.save_dev_input:
  26. if self.valid_x is None:
  27. valid_x, valid_y = network.prepare_input(data)
  28. self.valid_x = valid_x
  29. self.valid_y = valid_y
  30. else:
  31. valid_x = self.valid_x
  32. valid_y = self.valid_y
  33. else:
  34. valid_x, valid_y = network.prepare_input(data)
  35. # split into batches by self.batch_size
  36. iterations, test_batch_generator = self.batchify(self.batch_size, valid_x, valid_y)
  37. batch_output = list()
  38. loss_history = list()
  39. # turn on the testing mode of the network
  40. network.mode(test=True)
  41. for step in range(iterations):
  42. batch_x, batch_y = test_batch_generator.__next__()
  43. # forward pass from tests input to predicted output
  44. prediction = network.data_forward(batch_x)
  45. loss = network.get_loss(prediction, batch_y)
  46. if self.save_output:
  47. batch_output.append(prediction.data)
  48. if self.save_loss:
  49. loss_history.append(loss)
  50. self.log(self.make_log(step, loss))
  51. if self.save_loss:
  52. self.mean_loss = np.mean(np.array(loss_history))
  53. if self.save_output:
  54. self.output = self.make_output(batch_output)
  55. @property
  56. def loss(self):
  57. return self.mean_loss
  58. @property
  59. def result(self):
  60. return self.output
  61. @staticmethod
  62. def make_output(batch_outputs):
  63. # construct full prediction with batch outputs
  64. return np.concatenate(batch_outputs, axis=0)
  65. def load_config(self, args):
  66. raise NotImplementedError
  67. def load_dataset(self, args):
  68. raise NotImplementedError

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。