You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sliding_window_transpose.cpp 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/rng.h"
  4. #include "test/common/sliding_window_transpose.h"
  5. #include "test/cuda/benchmark.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(CUDA, SLIDINGWINDOWTRANSPOSE_FORWARD) {
  9. UniformFloatRNG rng(0, 1);
  10. auto args = sliding_window_transpose::get_args();
  11. for (auto&& arg : args) {
  12. Checker<SlidingWindowTransposeForward> checker(handle_cuda());
  13. checker.set_rng(0, &rng);
  14. checker.set_epsilon(1e-2);
  15. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  16. TensorLayout olayout;
  17. {
  18. auto opr = handle_cuda()->create_operator<SlidingWindowTransposeForward>();
  19. opr->param() = arg.param;
  20. opr->deduce_layout(ilayout, olayout);
  21. }
  22. auto set_dtype = [&checker](DType dtype) {
  23. checker.set_dtype(0, dtype).set_dtype(1, dtype);
  24. };
  25. set_dtype(dtype::Float32());
  26. checker.set_param(arg.param).exec(TensorShapeArray{ilayout, olayout});
  27. set_dtype(dtype::Float16());
  28. checker.set_param(arg.param).exec(TensorShapeArray{ilayout, olayout});
  29. }
  30. }
  31. TEST_F(CUDA, SLIDINGWINDOWTRANSPOSE_BACKWARD) {
  32. UniformFloatRNG rng(0, 1);
  33. auto args = sliding_window_transpose::get_args();
  34. for (auto&& arg : args) {
  35. Checker<SlidingWindowTransposeBackward> checker(handle_cuda());
  36. // checker.set_epsilon(1e-2);
  37. checker.set_rng(0, &rng);
  38. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  39. TensorLayout olayout;
  40. {
  41. auto opr = handle_cuda()->create_operator<SlidingWindowTranspose>();
  42. opr->param() = arg.param;
  43. opr->deduce_layout(ilayout, olayout);
  44. }
  45. auto set_dtype = [&checker](DType dtype) {
  46. checker.set_dtype(0, dtype).set_dtype(1, dtype);
  47. };
  48. set_dtype(dtype::Float32());
  49. checker.set_param(arg.param).exec(TensorShapeArray{olayout, ilayout});
  50. set_dtype(dtype::Float16());
  51. checker.set_param(arg.param).exec(TensorShapeArray{olayout, ilayout});
  52. }
  53. }
  54. #if MEGDNN_WITH_BENCHMARK
  55. TEST_F(CUDA, BENCHMARK_SLIDINGWINDOWTRANSPOSE_FORWARD) {
  56. auto args = sliding_window_transpose::get_benchmark_args();
  57. for (auto&& arg : args) {
  58. CUBenchmarker<SlidingWindowTransposeForward> bencher(handle_cuda());
  59. bencher.set_param(arg.param)
  60. .set_dtype(0, dtype::Float32())
  61. .exec(TensorShapeArray{arg.ishape, {}});
  62. }
  63. }
  64. #endif
  65. } // namespace test
  66. } // namespace megdnn
  67. // vim: syntax=cpp.doxygen