You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

pooling.cpp 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. /**
  2. * \file dnn/test/rocm/pooling.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "test/rocm/fixture.h"
  13. #include "megdnn/tensor_iter.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/pooling.h"
  16. #include "test/rocm/benchmarker.h"
  17. #include "src/common/utils.h"
  18. #include "src/rocm/utils.h"
  19. namespace megdnn {
  20. namespace test {
  21. TEST_F(ROCM, POOLING_FORWARD) {
  22. auto args = pooling::get_args();
  23. using Format = param::Pooling::Format;
  24. std::vector<DType> dtypes{DNN_INC_FLOAT16(dtype::Float16() MEGDNN_COMMA)
  25. dtype::Float32()};
  26. for (auto dtype : dtypes)
  27. for (auto format : {Format::NCHW})
  28. for (auto&& arg : args) {
  29. auto param = arg.param;
  30. auto src = arg.ishape;
  31. param.format = format;
  32. Checker<Pooling> checker(handle_rocm());
  33. checker.set_epsilon(1e-2);
  34. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
  35. TensorShapeArray{src, {}});
  36. }
  37. }
/// Checks PoolingBackward on the ROCm backend against the reference for
/// every predefined pooling test case (Float32 only; the fp16 variant is
/// disabled below due to a MIOpen bug).
TEST_F(ROCM, POOLING_BACKWARD) {
    auto args = pooling::get_args();
    for (auto&& arg : args) {
        Checker<PoolingBackward> checker(handle_rocm());
        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
        TensorLayout olayout;
        // PoolingBackward consumes the forward pass's output (tensor 1)
        // alongside its input (tensor 0). This constraint runs the forward
        // operator on the ROCm handle with the checker-generated input and
        // writes the real forward output back into the checker's host
        // tensors, so the backward pass under test sees consistent data.
        // NOTE(review): statement order matters here (H2D copy before exec,
        // D2H copy after) — do not reorder.
        auto constraint = [this, arg](CheckerHelper::TensorValueArray& tensors_orig) {
            // expected layout: {src, dst, diff, grad}
            megdnn_assert(tensors_orig.size() == 4);
            auto opr = handle_rocm()->create_operator<PoolingForward>();
            opr->param() = arg.param;
            // device-side scratch tensors for the forward src/dst pair
            auto tensors_rocm_storage = CheckerHelper::alloc_tensors(
                    handle_rocm(), {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
            auto&& tensors_rocm = *tensors_rocm_storage;
            // upload the generated forward input to the device
            auto span = tensors_rocm[0].layout.span();
            auto dst = static_cast<dt_byte*>(tensors_rocm[0].raw_ptr()) + span.low_byte;
            auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr()) +
                    span.low_byte;
            megdnn_memcpy_H2D(handle_rocm(), dst, src, span.dist_byte());
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors_rocm[0].layout, tensors_rocm[1].layout);
            auto workspace_rocm = megdnn_malloc(handle_rocm(), workspace_size);
            Workspace workspace{static_cast<dt_byte*>(workspace_rocm), workspace_size};
            // run the forward pass on the device
            opr->exec(tensors_rocm[0], tensors_rocm[1], workspace);
            megdnn_free(handle_rocm(), workspace_rocm);
            // download the forward output into the checker's dst tensor
            span = tensors_rocm[1].layout.span();
            dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr()) + span.low_byte;
            src = static_cast<const dt_byte*>(tensors_rocm[1].raw_ptr()) +
                    span.low_byte;
            megdnn_memcpy_D2H(handle_rocm(), dst, src, span.dist_byte());
        };
        // deduce the forward output layout for this case (scoped so the
        // temporary operator is released immediately)
        {
            auto opr = handle_rocm()->create_operator<PoolingForward>();
            opr->param() = arg.param;
            opr->deduce_layout(ilayout, olayout);
        }
        // apply one dtype to all four tensors: src, dst, diff, grad
        auto set_dtype = [&checker](DType dtype) {
            checker.set_dtype(0, dtype)
                    .set_dtype(1, dtype)
                    .set_dtype(2, dtype)
                    .set_dtype(3, dtype);
        };
        checker.set_tensors_constraint(constraint);
        set_dtype(dtype::Float32());
        checker.set_param(arg.param).exec(
                TensorShapeArray{ilayout, olayout, olayout, ilayout});
#if !MEGDNN_DISABLE_FLOAT16
        //! FIXME: MIOpen pooling backward for fp16 with bug
#if 0
        Float16PeriodicalRNG rng;
        set_dtype(dtype::Float16());
        checker.set_param(arg.param).set_rng(0, &rng).set_epsilon(1e-2).exec(
                TensorShapeArray{ilayout, olayout, olayout, ilayout});
#endif
#endif
    }
}
  94. #if MEGDNN_WITH_BENCHMARK
  95. TEST_F(ROCM, POOLING_FWD_BENCHMARK) {
  96. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  97. auto benchmarker =
  98. ROCMBenchmarker<PoolingForward>(handle_rocm(), handle_naive(false));
  99. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t SH = 1,
  100. size_t SW = 1, size_t FH = 1, size_t FW = 1, size_t PH = 0,
  101. size_t PW = 0, DType dtype = dtype::Float32()) {
  102. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  103. benchmarker.set_display(true);
  104. PoolingForward::Param param;
  105. param.mode = param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  106. param.stride_h = SH;
  107. param.stride_w = SW;
  108. param.pad_h = PH;
  109. param.pad_w = PW;
  110. param.window_h = FH;
  111. param.window_w = FW;
  112. benchmarker.set_param(param);
  113. size_t OH = infer_conv_shape(IH, FH, SH, PH);
  114. size_t OW = infer_conv_shape(IW, FW, SW, PW);
  115. // warm up
  116. benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  117. // do actual benchmark
  118. auto time_ms = benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  119. time_ms = benchmarker.execs({{N, IC, IH, IW}, {N, IC, OH, OW}});
  120. auto io = (double)N * IC * OH * OW * (1 + FH * FW) * dtype.size();
  121. auto gbps = io / (time_ms * 1e6);
  122. printf("io %.2fGB, flops %.3fGB/s\n", io / 1e9, gbps);
  123. };
  124. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0);
  125. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0);
  126. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0);
  127. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0);
  128. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1);
  129. }
  130. TEST_F(ROCM, POOLING_BWD_BENCHMARK) {
  131. using Mode = param::Pooling::Mode;
  132. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  133. auto benchmarker =
  134. ROCMBenchmarker<PoolingBackward>(handle_rocm(), handle_naive(false));
  135. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t SH = 1,
  136. size_t SW = 1, size_t FH = 1, size_t FW = 1, size_t PH = 0,
  137. size_t PW = 0, Mode mode = Mode::AVERAGE_COUNT_EXCLUDE_PADDING,
  138. DType dtype = dtype::Float32()) {
  139. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  140. benchmarker.set_display(true);
  141. PoolingForward::Param param;
  142. param.mode = mode;
  143. param.stride_h = SH;
  144. param.stride_w = SW;
  145. param.pad_h = PH;
  146. param.pad_w = PW;
  147. param.window_h = FH;
  148. param.window_w = FW;
  149. benchmarker.set_param(param);
  150. size_t OH = infer_conv_shape(IH, FH, SH, PH);
  151. size_t OW = infer_conv_shape(IW, FW, SW, PW);
  152. // warm up
  153. benchmarker.execs(
  154. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  155. // do actual benchmark
  156. auto time_ms = benchmarker.execs(
  157. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  158. time_ms = benchmarker.execs(
  159. {{N, IC, IH, IW}, {N, IC, OH, OW}, {N, IC, OH, OW}, {N, IC, IH, IW}});
  160. double io = 0.;
  161. double gbps = 0.;
  162. if (mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING) {
  163. io = (double)N * IC * OH * OW * FH * FW * 2 * dtype.size();
  164. gbps = io / (time_ms * 1e6);
  165. } else {
  166. io = (double)N * IC * OH * OW * 2 * dtype.size();
  167. gbps = io / (time_ms * 1e6);
  168. }
  169. printf("Mode = %s, io %.2fGB, flops %.3fGB/s\n",
  170. mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING ? "AVERAGE" : "MAX",
  171. io / 1e9, gbps);
  172. };
  173. Mode mode = Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  174. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0, mode);
  175. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0, mode);
  176. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0, mode);
  177. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0, mode);
  178. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1, mode);
  179. mode = Mode::MAX;
  180. run(32, 128, 80, 64, 2, 2, 2, 2, 0, 0, mode);
  181. run(32, 128, 40, 128, 2, 2, 2, 2, 0, 0, mode);
  182. run(32, 224, 40, 32, 2, 2, 2, 2, 0, 0, mode);
  183. run(32, 24, 160, 128, 2, 2, 4, 4, 0, 0, mode);
  184. run(32, 24, 160, 128, 2, 2, 4, 4, 1, 1, mode);
  185. }
  186. #endif
  187. } // namespace test
  188. } // namespace megdnn
  189. // vim: syntax=cpp.doxygen