You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

softmax.cpp 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/checker.h"
  4. using namespace megdnn;
  5. using namespace test;
  6. TEST_F(NAIVE, SOFTMAX_FORWARD) {
  7. Checker<Softmax> checker(handle(), /* check_dispatch */ false);
  8. Softmax::Param param{0};
  9. TensorND input = TensorValue(
  10. {2, 2, 2, 2}, dtype::Float32(),
  11. {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
  12. TensorND output = TensorValue(
  13. {2, 2, 2, 2}, dtype::Float32(),
  14. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  15. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  16. checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
  17. }
  18. TEST_F(NAIVE, SOFTMAX_BACKWARD) {
  19. Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);
  20. Softmax::Param param{0};
  21. TensorND input = TensorValue(
  22. {2, 2, 2, 2}, dtype::Float32(),
  23. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  24. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  25. TensorND diff = TensorValue(
  26. {2, 2, 2, 2}, dtype::Float32(),
  27. {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
  28. TensorND output = TensorValue(
  29. {2, 2, 2, 2}, dtype::Float32(),
  30. {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
  31. checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
  32. }