You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

sliding_window_transpose.h 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #pragma once
  2. #include <cstddef>
  3. #include "megdnn/basic_types.h"
  4. #include "megdnn/opr_param_defs.h"
  5. namespace megdnn {
  6. namespace test {
  7. namespace sliding_window_transpose {
  8. struct TestArg {
  9. param::SlidingWindowTranspose param;
  10. TensorShape ishape;
  11. TestArg(param::SlidingWindowTranspose param, TensorShape ishape)
  12. : param(param), ishape(ishape) {}
  13. };
  14. inline std::vector<TestArg> get_args() {
  15. std::vector<TestArg> args;
  16. // clang-format off
  17. for (uint32_t ih : {25, 96})
  18. for (uint32_t iw : {26, 128})
  19. for (uint32_t ph : {0, 1})
  20. for (uint32_t pw : {0, 1})
  21. for (uint32_t sh : {1, 2})
  22. for (uint32_t sw : {1, 2})
  23. for (uint32_t dh : {1, 2})
  24. for (uint32_t dw : {1, 2})
  25. for (uint32_t wh : {3, 4})
  26. for (uint32_t ww : {3, 4}) {
  27. unsigned long int oh = (ih + 2 * ph - dh * (wh-1)-1) / sh + 1;
  28. unsigned long int ow = (iw + 2 * pw - dw * (ww-1)-1) / sw + 1;
  29. args.emplace_back(param::SlidingWindowTranspose{ih, iw, ph, pw, sh, sw, dh, dw, wh, ww},
  30. TensorShape{2, 3, oh, ow, wh, ww});
  31. }
  32. // clang-format on
  33. // large window case
  34. args.emplace_back(
  35. param::SlidingWindowTranspose{96, 128, 0, 0, 1, 1, 1, 1, 32, 64},
  36. TensorShape{2, 3, 65, 65, 32, 64});
  37. // // large size
  38. args.emplace_back(
  39. param::SlidingWindowTranspose{28, 24, 0, 0, 1, 1, 1, 1, 1, 1},
  40. TensorShape{128, 128, 28, 24, 1, 1});
  41. return args;
  42. }
  43. inline std::vector<TestArg> get_benchmark_args() {
  44. std::vector<TestArg> args;
  45. // clang-format off
  46. for (uint32_t ph : {0, 1})
  47. for (uint32_t pw : {0, 1})
  48. for (uint32_t sh : {1, 2})
  49. for (uint32_t sw : {1, 2})
  50. for (uint32_t dh : {1, 2})
  51. for (uint32_t dw : {1, 2})
  52. for (uint32_t wh : {3, 4})
  53. for (uint32_t ww : {3, 4})
  54. for (uint32_t b : {1, 64})
  55. for (uint32_t c : {64, 128})
  56. for (uint32_t hw : {64, 128}) {
  57. unsigned long int o_hw = (hw + 2 * ph - dh * (wh-1)-1) / sh + 1;
  58. args.emplace_back(param::SlidingWindowTranspose{hw, hw, ph, pw, sh, sw, dh, dw, wh, ww},
  59. TensorShape{b, c, o_hw, o_hw, wh, ww});
  60. }
  61. // clang-format on
  62. // large size
  63. args.emplace_back(
  64. param::SlidingWindowTranspose{28, 24, 0, 0, 1, 1, 1, 1, 1, 1},
  65. TensorShape{1024, 128, 28, 24, 1, 1});
  66. return args;
  67. }
  68. } // namespace sliding_window_transpose
  69. } // namespace test
  70. } // namespace megdnn
  71. // vim: syntax=cpp.doxygen