You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_model.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import numpy as np
  class BaseModel(object):
      """Interface that every model in the toolkit is expected to implement.

      Every hook below raises NotImplementedError; concrete models override
      them. NOTE(review): the class derives from plain ``object`` rather than
      ``torch.nn.Module``, even though it is described as a PyTorch base model.
      """

      def __init__(self):
          """Nothing to set up at the base level."""

      def prepare_input(self, data):
          """Convert raw data into model input.

          :param data: raw input; its concrete type is decided by the subclass.
          :return: tuple ``(X, Y)`` holding input features and labels.
          """
          raise NotImplementedError

      def mode(self, test=False):
          """Select training or test mode via the ``test`` flag."""
          raise NotImplementedError

      def data_forward(self, *x):
          """Run the forward pass over the given inputs.

          Presumably mirrors the forward hook PyTorch ``nn`` modules require.
          """
          raise NotImplementedError

      def grad_backward(self):
          """Run the backward (gradient) pass."""
          raise NotImplementedError

      def loss(self, pred, truth):
          """Return the loss between predictions ``pred`` and labels ``truth``."""
          raise NotImplementedError
  class ToyModel(BaseModel):
      """Minimal linear model used for code testing.

      Predictions are ``x @ weight + bias`` with randomly initialised
      parameters; the loss is mean squared error.
      """

      def __init__(self, input_dim=5):
          """Create the model with random parameters.

          :param input_dim: int, number of feature columns ``data_forward``
              expects. Defaults to 5, the previously hard-coded size.
          """
          super(ToyModel, self).__init__()
          self.test_mode = False
          # Column vector of shape (input_dim, 1), so data_forward returns (n, 1).
          self.weight = np.random.rand(input_dim, 1)
          self.bias = np.random.rand()
          self._loss = 0

      def prepare_input(self, data):
          """Split a 2-D array into features and labels.

          :param data: 2-D array; every column but the last is a feature and
              the last column is the label.
          :return (X, Y): tuple of features (n, d) and labels (n,).
          """
          return data[:, :-1], data[:, -1]

      def mode(self, test=False):
          """Record whether the model is in test mode (``test=True``) or training mode."""
          self.test_mode = test

      def data_forward(self, x):
          """Return ``x @ weight + bias``; shape (n, 1) for ``x`` of shape (n, input_dim)."""
          return np.matmul(x, self.weight) + self.bias

      def grad_backward(self):
          """Placeholder backward pass: only prints a message."""
          print("loss gradient backward")

      def loss(self, pred, truth):
          """Return the mean squared error of ``pred`` against ``truth``.

          The value is also cached in ``self._loss``.
          """
          pred = np.asarray(pred)
          truth = np.asarray(truth)
          # data_forward yields (n, 1) predictions while prepare_input yields
          # (n,) labels; subtracting those directly broadcasts to an (n, n)
          # matrix and averages every pairwise difference. When the element
          # counts agree, compare element-wise instead.
          if pred.shape != truth.shape and pred.size == truth.size:
              pred = pred.reshape(truth.shape)
          self._loss = np.mean(np.square(pred - truth))
          return self._loss
  class Vocabulary(object):
      """A collection of lookup tables mapping words to embedding rows.

      All tables start as ``None`` and must be populated before ``lookup`` can
      succeed.
      """

      def __init__(self):
          # Set of known words; None until populated.
          self.word_set = None
          # Maps each word to its row index in emb_matrix.
          self.word2idx = None
          # Indexed by the integers stored in word2idx.
          self.emb_matrix = None

      def lookup(self, word):
          """Return the embedding row for ``word``.

          :param word: key to look up in the vocabulary.
          :return: ``emb_matrix[word2idx[word]]``.
          :raises LookupError: if ``word`` is not in the vocabulary, or if the
              vocabulary has not been populated yet.
          """
          # Raise rather than return the error, so a missing word can never be
          # mistaken for an embedding by the caller.
          if self.word_set is None or word not in self.word_set:
              raise LookupError("The key " + str(word) + " does not exist.")
          return self.emb_matrix[self.word2idx[word]]
  class Document(object):
      """A sequence of tokens, each a character carrying linguistic attributes.

      The token table is meant to be a wrapped pandas DataFrame; it stays
      ``None`` until something populates it.
      """

      def __init__(self):
          super(Document, self).__init__()
          # Placeholder for the wrapped pandas DataFrame; nothing is built here.
          self.dataframe = None

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。