You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

oneflow_model.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
  2. if _NEED_IMPORT_ONEFLOW:
  3. import oneflow
  4. from oneflow.nn import Module
  5. import oneflow.nn as nn
  6. else:
  7. from fastNLP.core.utils.dummy_class import DummyClass as Module
  8. # 1. 最为基础的分类模型
  9. class OneflowNormalModel_Classification_1(Module):
  10. """
  11. 单独实现 train_step 和 evaluate_step;
  12. """
  13. def __init__(self, num_labels, feature_dimension):
  14. super(OneflowNormalModel_Classification_1, self).__init__()
  15. self.num_labels = num_labels
  16. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  17. self.ac1 = nn.ReLU()
  18. self.linear2 = nn.Linear(in_features=10, out_features=10)
  19. self.ac2 = nn.ReLU()
  20. self.output = nn.Linear(in_features=10, out_features=num_labels)
  21. self.loss_fn = nn.CrossEntropyLoss()
  22. def forward(self, x):
  23. x = self.ac1(self.linear1(x))
  24. x = self.ac2(self.linear2(x))
  25. x = self.output(x)
  26. return x
  27. def train_step(self, x, y):
  28. x = self(x)
  29. return {"loss": self.loss_fn(x, y)}
  30. def evaluate_step(self, x, y):
  31. """
  32. 如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"};
  33. """
  34. x = self(x)
  35. x = oneflow.max(x, dim=-1)[1]
  36. return {"pred": x, "target": y}