You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pooling.cpp 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/fixture.h"
  3. #include "megdnn/tensor_iter.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/pooling.h"
  6. #include "test/rocm/benchmarker.h"
  7. #include "src/common/utils.h"
  8. #include "src/rocm/utils.h"
  9. namespace megdnn {
  10. namespace test {
  11. TEST_F(ROCM, POOLING_FORWARD) {
  12. auto args = pooling::get_args();
  13. using Format = param::Pooling::Format;
  14. std::vector<DType> dtypes{DNN_INC_FLOAT16(dtype::Float16() MEGDNN_COMMA)
  15. dtype::Float32()};
  16. for (auto dtype : dtypes)
  17. for (auto format : {Format::NCHW})
  18. for (auto&& arg : args) {
  19. auto param = arg.param;
  20. auto src = arg.ishape;
  21. param.format = format;
  22. Checker<Pooling> checker(handle_rocm());
  23. checker.set_epsilon(1e-2);
  24. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
  25. TensorShapeArray{src, {}});
  26. }
  27. }
  28. TEST_F(ROCM, POOLING_BACKWARD) {
  29. auto args = pooling::get_args();
  30. for (auto&& arg : args) {
  31. Checker<PoolingBackward> checker(handle_rocm());
  32. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  33. TensorLayout olayout;
  34. auto constraint = [this, arg](CheckerHelper::TensorValueArray& tensors_orig) {
  35. megdnn_assert(tensors_orig.size() == 4);
  36. auto opr = handle_rocm()->create_operator<PoolingForward>();
  37. opr->param() = arg.param;
  38. auto tensors_rocm_storage = CheckerHelper::alloc_tensors(
  39. handle_rocm(), {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
  40. auto&& tensors_rocm = *tensors_rocm_storage;
  41. auto span = tensors_rocm[0].layout.span();
  42. auto dst = static_cast<dt_byte*>(tensors_rocm[0].raw_ptr()) + span.low_byte;
  43. auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr()) +
  44. span.low_byte;
  45. megdnn_memcpy_H2D(handle_rocm(), dst, src, span.dist_byte());
  46. auto workspace_size = opr->get_workspace_in_bytes(
  47. tensors_rocm[0].layout, tensors_rocm[1].layout);
  48. auto workspace_rocm = megdnn_malloc(handle_rocm(), workspace_size);
  49. Workspace workspace{static_cast<dt_byte*>(workspace_rocm), workspace_size};
  50. opr->exec(tensors_rocm[0], tensors_rocm[1], workspace);
  51. megdnn_free(handle_rocm(), workspace_rocm);
  52. span = tensors_rocm[1].layout.span();
  53. dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr()) + span.low_byte;
  54. src = static_cast<const dt_byte*>(tensors_rocm[1].raw_ptr()) +
  55. span.low_byte;
  56. megdnn_memcpy_D2H(handle_rocm(), dst, src, span.dist_byte());
  57. };
  58. {
  59. auto opr = handle_rocm()->create_operator<PoolingForward>();
  60. opr->param() = arg.param;
  61. opr->deduce_layout(ilayout, olayout);
  62. }
  63. auto set_dtype = [&checker](DType dtype) {
  64. checker.set_dtype(0, dtype)
  65. .set_dtype(1, dtype)
  66. .set_dtype(2, dtype)
  67. .set_dtype(3, dtype);
  68. };
  69. checker.set_tensors_constraint(constraint);
  70. set_dtype(dtype::Float32());
  71. checker.set_param(arg.param).exec(
  72. TensorShapeArray{ilayout, olayout, olayout, ilayout});
  73. #if !MEGDNN_DISABLE_FLOAT16
  74. //! FIXME: MIOpen pooling backward for fp16 with bug
  75. #if 0
  76. Float16PeriodicalRNG rng;
  77. set_dtype(dtype::Float16());
  78. checker.set_param(arg.param).set_rng(0, &rng).set_epsilon(1e-2).exec(
  79. TensorShapeArray{ilayout, olayout, olayout, ilayout});
  80. #endif
  81. #endif
  82. }
  83. }
  84. #if MEGDNN_WITH_BENCHMARK
  85. TEST_F(ROCM, POOLING_FWD_BENCHMARK) {
  86. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  87. auto benchmarker =
  88. ROCMBenchmarker<PoolingForward>(handle_rocm(), handle_naive(false));
  89. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t SH = 1,
  90. size_t SW = 1, size_t FH = 1, size_t FW = 1, size_t PH = 0,
  91. size_t PW = 0, DType dtype = dtype::Float32()) {
  92. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  93. benchmarker.set_display(true);
  94. PoolingForward::Param param;
  95. param.mode = param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  96. param.stride_h = SH;
  97. param.stride_w = SW;
  98. param.pad_h = PH;
  99. param.pad_w = PW;
  100. param.window_h = FH;
  101. param.window_w = FW;
  102. benchmarker.set_param(param);
  103. size_t OH = infer_conv_shape(IH, FH, SH, PH);
  104. size_t OW = infer_conv_shape(IW, FW, SW, PW);
  105. // warm up
  106. benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  107. // do actual benchmark
  108. auto time_ms = benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  109. time_ms = benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  110. auto io = (double)N * IC * OH * OW * (1 + FH * FW) * dtype.size();
  111. auto gbps = io / (time_ms * 1e6);
  112. printf("io %.2fGB, flops %.3fGB/s\n", io / 1e9, gbps);
  113. };
  114. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0);
  115. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0);
  116. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0);
  117. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0);
  118. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1);
  119. }
  120. TEST_F(ROCM, POOLING_BWD_BENCHMARK) {
  121. using Mode = param::Pooling::Mode;
  122. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  123. auto benchmarker =
  124. ROCMBenchmarker<PoolingBackward>(handle_rocm(), handle_naive(false));
  125. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t SH = 1,
  126. size_t SW = 1, size_t FH = 1, size_t FW = 1, size_t PH = 0,
  127. size_t PW = 0, Mode mode = Mode::AVERAGE_COUNT_EXCLUDE_PADDING,
  128. DType dtype = dtype::Float32()) {
  129. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  130. benchmarker.set_display(true);
  131. PoolingForward::Param param;
  132. param.mode = mode;
  133. param.stride_h = SH;
  134. param.stride_w = SW;
  135. param.pad_h = PH;
  136. param.pad_w = PW;
  137. param.window_h = FH;
  138. param.window_w = FW;
  139. benchmarker.set_param(param);
  140. size_t OH = infer_conv_shape(IH, FH, SH, PH);
  141. size_t OW = infer_conv_shape(IW, FW, SW, PW);
  142. // warm up
  143. benchmarker.execs(
  144. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  145. // do actual benchmark
  146. auto time_ms = benchmarker.execs(
  147. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  148. time_ms = benchmarker.execs(
  149. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  150. double io = 0.;
  151. double gbps = 0.;
  152. if (mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING) {
  153. io = (double)N * IC * OH * OW * FH * FW * 2 * dtype.size();
  154. gbps = io / (time_ms * 1e6);
  155. } else {
  156. io = (double)N * IC * OH * OW * 2 * dtype.size();
  157. gbps = io / (time_ms * 1e6);
  158. }
  159. printf("Mode = %s, io %.2fGB, flops %.3fGB/s\n",
  160. mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING ? "AVERAGE" : "MAX",
  161. io / 1e9, gbps);
  162. };
  163. Mode mode = Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  164. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0, mode);
  165. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0, mode);
  166. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0, mode);
  167. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0, mode);
  168. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1, mode);
  169. mode = Mode::MAX;
  170. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0, mode);
  171. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0, mode);
  172. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0, mode);
  173. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0, mode);
  174. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1, mode);
  175. }
  176. #endif
  177. } // namespace test
  178. } // namespace megdnn
  179. // vim: syntax=cpp.doxygen