You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_model.py 980 B

1234567891011121314151617181920212223242526272829303132
  1. import paddle
  2. import paddle.nn as nn
  3. class PaddleNormalModel_Classification(paddle.nn.Layer):
  4. """
  5. 基础的paddle分类模型
  6. """
  7. def __init__(self, num_labels, feature_dimension):
  8. super(PaddleNormalModel_Classification, self).__init__()
  9. self.num_labels = num_labels
  10. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
  11. self.ac1 = nn.ReLU()
  12. self.linear2 = nn.Linear(in_features=64, out_features=32)
  13. self.ac2 = nn.ReLU()
  14. self.output = nn.Linear(in_features=32, out_features=num_labels)
  15. self.loss_fn = nn.CrossEntropyLoss()
  16. def forward(self, x):
  17. x = self.ac1(self.linear1(x))
  18. x = self.ac2(self.linear2(x))
  19. x = self.output(x)
  20. return x
  21. def train_step(self, x, y):
  22. x = self(x)
  23. return {"loss": self.loss_fn(x, y)}
  24. def validate_step(self, x, y):
  25. x = self(x)
  26. return {"pred": x, "target": y.reshape((-1,))}