You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

softmax.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/checker.h"
  4. using namespace megdnn;
  5. using namespace test;
  6. TEST_F(NAIVE, SOFTMAX_FORWARD) {
  7. Checker<Softmax> checker(handle(), /* check_dispatch */ false);
  8. Softmax::Param param{0};
  9. TensorND input = TensorValue(
  10. {2, 2, 2, 2}, dtype::Float32(),
  11. {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
  12. TensorND output = TensorValue(
  13. {2, 2, 2, 2}, dtype::Float32(),
  14. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  15. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  16. checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
  17. }
  18. TEST_F(NAIVE, SOFTMAX_BACKWARD) {
  19. Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);
  20. Softmax::Param param{0};
  21. TensorND input = TensorValue(
  22. {2, 2, 2, 2}, dtype::Float32(),
  23. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  24. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  25. TensorND diff = TensorValue(
  26. {2, 2, 2, 2}, dtype::Float32(),
  27. {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
  28. TensorND output = TensorValue(
  29. {2, 2, 2, 2}, dtype::Float32(),
  30. {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
  31. checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
  32. }
  33. TEST_F(NAIVE, SOFTMAX_FORWARD_NHWCD4) {
  34. Checker<Softmax> checker(handle(), false);
  35. Softmax::Param param{0};
  36. TensorND input1 = TensorValue(
  37. {1, 2, 1, 2, 4}, dtype::Float32(),
  38. {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15});
  39. TensorND output1 = TensorValue(
  40. {1, 2, 1, 2, 4}, dtype::Float32(),
  41. {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  42. checker.set_param(param).exect(Testcase{input1, {}}, Testcase{{}, output1});
  43. TensorND input2 = TensorValue(
  44. {2, 2, 1, 2, 4}, dtype::Float32(),
  45. {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15,
  46. 16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31});
  47. TensorND output2 = TensorValue(
  48. {2, 2, 1, 2, 4}, dtype::Float32(),
  49. {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  50. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  51. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  52. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  53. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  54. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  55. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  56. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01});
  57. checker.set_param(param).exect(Testcase{input2, {}}, Testcase{{}, output2});
  58. }
  59. TEST_F(NAIVE, SOFTMAX_BACKWARD_NHWCD4) {
  60. Checker<SoftmaxBackward> checker(handle(), false);
  61. Softmax::Param param{0};
  62. TensorND input = TensorValue(
  63. {2, 2, 1, 2, 4}, dtype::Float32(),
  64. {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  65. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  66. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  67. 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
  68. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  69. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  70. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
  71. 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01});
  72. TensorND diff = TensorValue(
  73. {2, 2, 1, 2, 4}, dtype::Float32(),
  74. {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
  75. 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
  76. TensorND output = TensorValue(
  77. {2, 2, 1, 2, 4}, dtype::Float32(),
  78. {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  79. 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
  80. checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
  81. }