diff --git a/dnn/src/naive/softmax/opr_impl.h b/dnn/src/naive/softmax/opr_impl.h
index 39cefbe2..e0a78277 100644
--- a/dnn/src/naive/softmax/opr_impl.h
+++ b/dnn/src/naive/softmax/opr_impl.h
@@ -4,7 +4,7 @@
 namespace megdnn {
 namespace naive {
 
-class SoftmaxForwardImpl final : public SoftmaxForward {
+class SoftmaxForwardImpl : public SoftmaxForward {
 public:
     using SoftmaxForward::SoftmaxForward;
     void exec(
@@ -16,7 +16,7 @@ public:
     }
 };
 
-class SoftmaxBackwardImpl final : public SoftmaxBackward {
+class SoftmaxBackwardImpl : public SoftmaxBackward {
 public:
     using SoftmaxBackward::SoftmaxBackward;
     void exec(
@@ -32,4 +32,4 @@ public:
 } // namespace naive
 } // namespace megdnn
 
-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
diff --git a/dnn/test/naive/softmax.cpp b/dnn/test/naive/softmax.cpp
index 94e278c8..4ac616bc 100644
--- a/dnn/test/naive/softmax.cpp
+++ b/dnn/test/naive/softmax.cpp
@@ -42,4 +42,61 @@ TEST_F(NAIVE, SOFTMAX_BACKWARD) {
             {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
 
     checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
-}
\ No newline at end of file
+}
+
+TEST_F(NAIVE, SOFTMAX_FORWARD_NHWCD4) {
+    Checker checker(handle(), false);
+    Softmax::Param param{0};
+
+    TensorND input1 = TensorValue(
+            {1, 2, 1, 2, 4}, dtype::Float32(),
+            {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15});
+    TensorND output1 = TensorValue(
+            {1, 2, 1, 2, 4}, dtype::Float32(),
+            {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+    checker.set_param(param).exect(Testcase{input1, {}}, Testcase{{}, output1});
+
+    TensorND input2 = TensorValue(
+            {2, 2, 1, 2, 4}, dtype::Float32(),
+            {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15,
+             16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31});
+    TensorND output2 = TensorValue(
+            {2, 2, 1, 2, 4}, dtype::Float32(),
+            {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01});
+    checker.set_param(param).exect(Testcase{input2, {}}, Testcase{{}, output2});
+}
+
+TEST_F(NAIVE, SOFTMAX_BACKWARD_NHWCD4) {
+    Checker checker(handle(), false);
+    Softmax::Param param{0};
+
+    TensorND input = TensorValue(
+            {2, 2, 1, 2, 4}, dtype::Float32(),
+            {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01,
+             9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01});
+
+    TensorND diff = TensorValue(
+            {2, 2, 1, 2, 4}, dtype::Float32(),
+            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
+
+    TensorND output = TensorValue(
+            {2, 2, 1, 2, 4}, dtype::Float32(),
+            {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
+
+    checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
+}