You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

torch_model.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import torch.nn as nn
  3. # 1. 最为基础的分类模型
  4. class TorchNormalModel_Classification_1(nn.Module):
  5. """
  6. 单独实现 train_step 和 evaluate_step;
  7. """
  8. def __init__(self, num_labels, feature_dimension):
  9. super(TorchNormalModel_Classification_1, self).__init__()
  10. self.num_labels = num_labels
  11. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  12. self.ac1 = nn.ReLU()
  13. self.linear2 = nn.Linear(in_features=10, out_features=10)
  14. self.ac2 = nn.ReLU()
  15. self.output = nn.Linear(in_features=10, out_features=num_labels)
  16. self.loss_fn = nn.CrossEntropyLoss()
  17. def forward(self, x):
  18. x = self.ac1(self.linear1(x))
  19. x = self.ac2(self.linear2(x))
  20. x = self.output(x)
  21. return x
  22. def train_step(self, x, y):
  23. x = self(x)
  24. return {"loss": self.loss_fn(x, y)}
  25. def validate_step(self, x, y):
  26. """
  27. 如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"};
  28. """
  29. x = self(x)
  30. x = torch.max(x, dim=-1)[1]
  31. return {"preds": x, "target": y}
  32. class TorchNormalModel_Classification_2(nn.Module):
  33. """
  34. 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
  35. """
  36. def __init__(self, num_labels, feature_dimension):
  37. super(TorchNormalModel_Classification_2, self).__init__()
  38. self.num_labels = num_labels
  39. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  40. self.ac1 = nn.ReLU()
  41. self.linear2 = nn.Linear(in_features=10, out_features=10)
  42. self.ac2 = nn.ReLU()
  43. self.output = nn.Linear(in_features=10, out_features=num_labels)
  44. self.loss_fn = nn.CrossEntropyLoss()
  45. def forward(self, x, y):
  46. x = self.ac1(self.linear1(x))
  47. x = self.ac2(self.linear2(x))
  48. x = self.output(x)
  49. loss = self.loss_fn(x, y)
  50. x = torch.max(x, dim=-1)[1]
  51. return {"loss": loss, "preds": x, "target": y}