You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

softmax.cpp 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. /**
  2. * \file dnn/test/naive/softmax.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/naive/fixture.h"
  13. #include "megdnn/oprs/nn.h"
  14. #include "test/common/checker.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. TEST_F(NAIVE, SOFTMAX_FORWARD) {
  18. Checker<Softmax> checker(handle(), /* check_dispatch */ false);
  19. Softmax::Param param{0};
  20. TensorND input = TensorValue(
  21. {2, 2, 2, 2}, dtype::Float32(),
  22. {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
  23. TensorND output = TensorValue(
  24. {2, 2, 2, 2}, dtype::Float32(),
  25. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  26. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  27. checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
  28. }
  29. TEST_F(NAIVE, SOFTMAX_BACKWARD) {
  30. Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);
  31. Softmax::Param param{0};
  32. TensorND input = TensorValue(
  33. {2, 2, 2, 2}, dtype::Float32(),
  34. {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
  35. 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
  36. TensorND diff = TensorValue(
  37. {2, 2, 2, 2}, dtype::Float32(),
  38. {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
  39. TensorND output = TensorValue(
  40. {2, 2, 2, 2}, dtype::Float32(),
  41. {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
  42. checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
  43. }