You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_model.py 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import numpy as np
  class BaseModel(object):
      """Common interface implemented by every model.

      The base class carries no state; each hook below must be overridden
      by a subclass, otherwise calling it raises ``NotImplementedError``.
      """

      def __init__(self):
          """Build an empty model (nothing to initialise at this level)."""

      def prepare_input(self, data):
          """Turn raw ``data`` into model inputs and targets.

          :param data: raw input; its concrete type is decided by the subclass
          :return (X, Y): tuple, input features and labels
          """
          raise NotImplementedError

      def mode(self, test=False):
          """Switch the model between training and test behaviour."""
          raise NotImplementedError

      def data_forward(self, x):
          """Run the forward pass on the features ``x``."""
          raise NotImplementedError

      def grad_backward(self):
          """Run the backward pass."""
          raise NotImplementedError

      def loss(self, pred, truth):
          """Compute the loss between predictions ``pred`` and targets ``truth``."""
          raise NotImplementedError
  class ToyModel(BaseModel):
      """Tiny linear model used only for exercising the surrounding code."""

      def __init__(self, in_dim=5):
          """
          :param in_dim: int, number of input features (columns of X);
              defaults to 5, the value that used to be hard-coded
          """
          super(ToyModel, self).__init__()
          self.test_mode = False
          # Column vector so that ``x @ weight`` produces an (n, 1) output.
          self.weight = np.random.rand(in_dim, 1)
          self.bias = np.random.rand()
          self._loss = 0

      def prepare_input(self, data):
          """Split a 2-D array into features and labels.

          :param data: 2-D array; the last column holds the labels
          :return (X, Y): X is every column but the last, Y is the last
              column as a 1-D array
          """
          return data[:, :-1], data[:, -1]

      def mode(self, test=False):
          """Record whether the model is in test mode (``test=True``) or not."""
          self.test_mode = test

      def data_forward(self, x):
          """Return ``x @ weight + bias`` (one output column per row of ``x``)."""
          return np.matmul(x, self.weight) + self.bias

      def grad_backward(self):
          """Placeholder backward pass: prints a message, updates nothing."""
          print("loss gradient backward")

      def loss(self, pred, truth):
          """Mean squared error between ``pred`` and ``truth``; also cached in ``self._loss``.

          :param pred: predictions, e.g. the (n, 1) output of ``data_forward``
          :param truth: targets, e.g. the 1-D labels from ``prepare_input``
          :return: the scalar mean squared error
          """
          pred = np.asarray(pred)
          truth = np.asarray(truth)
          # An (n, 1) prediction minus (n,) labels would broadcast to an
          # (n, n) matrix and silently give a wrong loss; when both hold the
          # same number of elements, compare them element-wise instead.
          if pred.shape != truth.shape and pred.size == truth.size:
              pred = pred.reshape(-1)
              truth = truth.reshape(-1)
          self._loss = np.mean(np.square(pred - truth))
          return self._loss
  class Vocabulary(object):
      """
      A collection of lookup tables mapping words to embedding rows.
      """

      def __init__(self):
          # All tables start unset and must be populated before lookup().
          self.word_set = None    # container used for the membership test
          self.word2idx = None    # maps a word to an index into emb_matrix
          self.emb_matrix = None  # indexed with the value from word2idx

      def lookup(self, word):
          """Return the embedding row for ``word``.

          :param word: the key to look up
          :return: ``self.emb_matrix[self.word2idx[word]]``
          :raises LookupError: if the tables are unset or ``word`` is not in
              ``word_set``
          """
          # Raise (rather than return) the error so a missing key can never be
          # mistaken for an embedding by the caller.
          if self.word_set is None or word not in self.word_set:
              # format() also handles non-str keys, unlike "+" concatenation.
              raise LookupError("The key {} does not exist.".format(word))
          return self.emb_matrix[self.word2idx[word]]
  class Document(object):
      """
      A sequence of tokens, where each token is a character carrying
      linguistic attributes.
      """

      def __init__(self):
          # Intended to wrap a pandas DataFrame; left unset for now.
          self.dataframe = None

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。