You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

softmax.cpp 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/tensor_iter.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/softmax.h"
  5. #include "src/common/utils.h"
  6. #include "test/cuda/utils.h"
  7. // to check cudnn version
  8. #include <cudnn.h>
  9. #include "test/cuda/benchmark.h"
  10. namespace megdnn {
  11. namespace test {
  12. TEST_F(CUDA, SOFTMAX_FORWARD) {
  13. auto args = softmax::get_args();
  14. std::vector<DType> dtypes{dtype::Float16(), dtype::Float32()};
  15. for (auto dtype : dtypes)
  16. for (auto&& arg : args) {
  17. auto param = arg.param;
  18. auto src = arg.ishape;
  19. Checker<Softmax> checker(handle_cuda());
  20. if (dtype == dtype::BFloat16()) {
  21. checker.set_epsilon(2e-2);
  22. } else {
  23. checker.set_epsilon(1e-2);
  24. }
  25. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
  26. TensorShapeArray{src, {}});
  27. }
  28. }
  29. TEST_F(CUDA, SOFTMAX_BACKWARD) {
  30. auto args = softmax::get_args();
  31. for (auto&& arg : args) {
  32. Checker<SoftmaxBackward> checker(handle_cuda());
  33. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  34. TensorLayout olayout;
  35. {
  36. auto opr = handle_cuda()->create_operator<SoftmaxForward>();
  37. opr->param() = arg.param;
  38. opr->deduce_layout(ilayout, olayout);
  39. }
  40. auto set_dtype = [&checker](DType dtype) {
  41. checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
  42. };
  43. set_dtype(dtype::Float32());
  44. checker.set_epsilon(1e-3).set_param(arg.param).exec(
  45. TensorShapeArray{ilayout, olayout, ilayout});
  46. }
  47. }
  48. } // namespace test
  49. } // namespace megdnn
  50. // vim: syntax=cpp.doxygen