You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pooling.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. /**
  2. * \file dnn/test/cuda/pooling.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "megdnn/tensor_iter.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/pooling.h"
  15. #include "src/common/utils.h"
  16. #include "test/cuda/utils.h"
  17. // to check cudnn version
  18. #include <cudnn.h>
  19. #include "test/cuda/benchmark.h"
  20. namespace megdnn {
  21. namespace test {
  22. TEST_F(CUDA, POOLING_FORWARD)
  23. {
  24. auto args = pooling::get_args();
  25. using Format = param::Pooling::Format;
  26. std::vector<DType> dtypes{dtype::Float16(), dtype::BFloat16(), dtype::Float32()};
  27. if (check_compute_capability(6, 0)) {
  28. // int pooling is supported only for Pascal or higher
  29. dtypes.push_back(dtype::Int8());
  30. }
  31. for (auto dtype: dtypes)
  32. for (auto format: {Format::NCHW, Format::NHWC})
  33. for (auto &&arg: args) {
  34. auto param = arg.param;
  35. auto src = arg.ishape;
  36. param.format = format;
  37. if (param.format == Format::NHWC) {
  38. src = cvt_src_or_dst_nchw2nhwc(src);
  39. }
  40. Checker<Pooling> checker(handle_cuda());
  41. if (dtype == dtype::Int8()) {
  42. // different versions of cuDNN differs in rounding behavior;
  43. // setting eps to 1 to allow for rounding errors.
  44. checker.set_epsilon(1 + 1e-3);
  45. } else if (dtype == dtype::BFloat16()) {
  46. checker.set_epsilon(2e-2);
  47. } else {
  48. checker.set_epsilon(1e-2);
  49. }
  50. checker.set_param(param)
  51. .set_dtype(0, dtype)
  52. .set_dtype(1, dtype)
  53. .exec(TensorShapeArray{
  54. src, {}});
  55. }
  56. /* add test for new Mode temporarily */
  57. for (auto dtype: dtypes)
  58. for (auto format: {Format::NCHW, Format::NHWC})
  59. for(auto &&arg : args) {
  60. auto param = arg.param;
  61. if(param.mode == Pooling::Mode::AVERAGE)
  62. param.mode = Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  63. else continue;
  64. auto src = arg.ishape;
  65. param.format = format;
  66. if (param.format == Format::NHWC) {
  67. src = cvt_src_or_dst_nchw2nhwc(src);
  68. }
  69. Checker<Pooling> checker(handle_cuda());
  70. if (dtype == dtype::Int8()) {
  71. // different versions of cuDNN differs in rounding behavior;
  72. // setting eps to 1 to allow for rounding errors.
  73. checker.set_epsilon(1 + 1e-3);
  74. } else if (dtype == dtype::BFloat16()) {
  75. checker.set_epsilon(2e-2);
  76. }
  77. else {
  78. checker.set_epsilon(1e-2);
  79. }
  80. checker.set_param(param)
  81. .set_dtype(0, dtype)
  82. .set_dtype(1, dtype)
  83. .exec(TensorShapeArray{
  84. src, {}});
  85. }
  86. }
  87. TEST_F(CUDA, POOLING_BACKWARD)
  88. {
  89. auto args = pooling::get_args();
  90. for (auto &&arg: args) {
  91. Checker<PoolingBackward> checker(handle_cuda());
  92. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  93. TensorLayout olayout;
  94. auto constraint = [this,
  95. arg](CheckerHelper::TensorValueArray& tensors_orig) {
  96. megdnn_assert(tensors_orig.size() == 4);
  97. auto opr = handle_cuda()->create_operator<PoolingForward>();
  98. opr->param() = arg.param;
  99. auto tensors_cuda_storage = CheckerHelper::alloc_tensors(
  100. handle_cuda(),
  101. {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
  102. auto&& tensors_cuda = *tensors_cuda_storage;
  103. auto span = tensors_cuda[0].layout.span();
  104. auto dst = static_cast<dt_byte*>(tensors_cuda[0].raw_ptr) +
  105. span.low_byte;
  106. auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr) +
  107. span.low_byte;
  108. megdnn_memcpy_H2D(handle_cuda(), dst, src, span.dist_byte());
  109. auto workspace_size = opr->get_workspace_in_bytes(
  110. tensors_cuda[0].layout, tensors_cuda[1].layout);
  111. auto workspace_cuda = megdnn_malloc(handle_cuda(), workspace_size);
  112. Workspace workspace{static_cast<dt_byte*>(workspace_cuda),
  113. workspace_size};
  114. opr->exec(tensors_cuda[0], tensors_cuda[1], workspace);
  115. megdnn_free(handle_cuda(), workspace_cuda);
  116. span = tensors_cuda[1].layout.span();
  117. dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr) +
  118. span.low_byte;
  119. src = static_cast<const dt_byte*>(tensors_cuda[1].raw_ptr) +
  120. span.low_byte;
  121. megdnn_memcpy_D2H(handle_cuda(), dst, src, span.dist_byte());
  122. };
  123. {
  124. auto opr = handle_cuda()->create_operator<PoolingForward>();
  125. opr->param() = arg.param;
  126. opr->deduce_layout(ilayout, olayout);
  127. }
  128. auto set_dtype = [&checker](DType dtype)
  129. {
  130. checker.set_dtype(0, dtype).
  131. set_dtype(1, dtype).
  132. set_dtype(2, dtype).
  133. set_dtype(3, dtype);
  134. };
  135. checker.set_tensors_constraint(constraint);
  136. set_dtype(dtype::Float32());
  137. checker.set_param(arg.param).exec(TensorShapeArray{
  138. ilayout, olayout, olayout, ilayout});
  139. Float16PeriodicalRNG rng;
  140. set_dtype(dtype::Float16());
  141. checker
  142. .set_param(arg.param)
  143. .set_rng(0, &rng)
  144. .set_epsilon(1e-2)
  145. .exec(TensorShapeArray{
  146. ilayout, olayout, olayout, ilayout});
  147. BFloat16PeriodicalRNG bf16_rng;
  148. set_dtype(dtype::BFloat16());
  149. checker.set_param(arg.param)
  150. .set_rng(0, &bf16_rng)
  151. .set_epsilon(1e-2)
  152. .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
  153. }
  154. /* add test for new Mode temporarily */
  155. for(auto &&arg : args) {
  156. if(arg.param.mode == Pooling::Mode::AVERAGE)
  157. arg.param.mode = Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  158. else continue;
  159. Checker<PoolingBackward> checker(handle_cuda());
  160. TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
  161. TensorLayout olayout;
  162. auto constraint = [this,
  163. arg](CheckerHelper::TensorValueArray& tensors_orig) {
  164. megdnn_assert(tensors_orig.size() == 4);
  165. auto opr = handle_cuda()->create_operator<PoolingForward>();
  166. opr->param() = arg.param;
  167. auto tensors_cuda_storage = CheckerHelper::alloc_tensors(
  168. handle_cuda(),
  169. {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
  170. auto&& tensors_cuda = *tensors_cuda_storage;
  171. auto span = tensors_cuda[0].layout.span();
  172. auto dst = static_cast<dt_byte*>(tensors_cuda[0].raw_ptr) +
  173. span.low_byte;
  174. auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr) +
  175. span.low_byte;
  176. megdnn_memcpy_H2D(handle_cuda(), dst, src, span.dist_byte());
  177. auto workspace_size = opr->get_workspace_in_bytes(
  178. tensors_cuda[0].layout, tensors_cuda[1].layout);
  179. auto workspace_cuda = megdnn_malloc(handle_cuda(), workspace_size);
  180. Workspace workspace{static_cast<dt_byte*>(workspace_cuda),
  181. workspace_size};
  182. opr->exec(tensors_cuda[0], tensors_cuda[1], workspace);
  183. megdnn_free(handle_cuda(), workspace_cuda);
  184. span = tensors_cuda[1].layout.span();
  185. dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr) +
  186. span.low_byte;
  187. src = static_cast<const dt_byte*>(tensors_cuda[1].raw_ptr) +
  188. span.low_byte;
  189. megdnn_memcpy_D2H(handle_cuda(), dst, src, span.dist_byte());
  190. };
  191. {
  192. auto opr = handle_cuda()->create_operator<PoolingForward>();
  193. opr->param() = arg.param;
  194. opr->deduce_layout(ilayout, olayout);
  195. }
  196. auto set_dtype = [&checker](DType dtype)
  197. {
  198. checker.set_dtype(0, dtype).
  199. set_dtype(1, dtype).
  200. set_dtype(2, dtype).
  201. set_dtype(3, dtype);
  202. };
  203. checker.set_tensors_constraint(constraint);
  204. set_dtype(dtype::Float32());
  205. checker.set_param(arg.param).exec(TensorShapeArray{
  206. ilayout, olayout, olayout, ilayout});
  207. Float16PeriodicalRNG rng;
  208. set_dtype(dtype::Float16());
  209. checker
  210. .set_param(arg.param)
  211. .set_rng(0, &rng)
  212. .set_epsilon(1e-2)
  213. .exec(TensorShapeArray{
  214. ilayout, olayout, olayout, ilayout});
  215. BFloat16PeriodicalRNG bf16_rng;
  216. set_dtype(dtype::BFloat16());
  217. checker.set_param(arg.param)
  218. .set_rng(0, &bf16_rng)
  219. .set_epsilon(1e-2)
  220. .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
  221. }
  222. }
  223. TEST_F(CUDA, POOLING_FORWARD_NCHW4) {
  224. require_compute_capability(7, 5);
  225. using Param = param::Pooling;
  226. Checker<Pooling> checker(handle_cuda());
  227. Param param;
  228. checker.set_dtype(0, dtype::QuantizedS8(0.1f));
  229. param.format = Param::Format::NCHW4;
  230. checker.set_epsilon(1 + 1e-3);
  231. checker.set_param(param).exec({{20, 3, 50, 50, 4}, {}});
  232. }
  #if CUDNN_VERSION >= 7500
  // Default-parameter forward pooling in NCHW32 layout on QuantizedS8 input
  // drawn uniformly from the full int8 range; needs cuDNN >= 7.5 and compute
  // capability 7.5.
  TEST_F(CUDA, POOLING_FORWARD_NCHW32) {
      require_compute_capability(7, 5);
      using Param = param::Pooling;
      Checker<Pooling> checker(handle_cuda());
      Param param;
      // min()/max() are static members of numeric_limits; no instance needed.
      UniformIntRNG int_rng{std::numeric_limits<int8_t>::min(),
                            std::numeric_limits<int8_t>::max()};
      checker.set_dtype(0, dtype::QuantizedS8(0.1f));
      param.format = Param::Format::NCHW32;
      checker.set_epsilon(1e-3).set_rng(0, &int_rng);
      checker.set_param(param).exec({{64, 8, 28, 28, 32}, {}});
  }
  #endif
  248. TEST_F(CUDA, POOLING_FORWARD_CHWN4) {
  249. require_compute_capability(6, 1);
  250. using Param = param::Pooling;
  251. Checker<Pooling> checker(handle_cuda());
  252. Param param;
  253. auto i8_min = std::numeric_limits<int8_t>().min();
  254. auto i8_max = std::numeric_limits<int8_t>().max();
  255. UniformIntRNG int_rng{i8_min, i8_max};
  256. checker.set_dtype(0, dtype::QuantizedS8(0.1f));
  257. param.format = Param::Format::CHWN4;
  258. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE,
  259. Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) {
  260. param.mode = mode;
  261. checker.set_epsilon(1e-3).set_rng(0, &int_rng);
  262. checker.set_param(param).exec({{8, 28, 28, 64, 4}, {}});
  263. checker.set_param(param).exec({{8, 28, 28, 15, 4}, {}});
  264. checker.set_param(param).exec({{8, 28, 28, 30, 4}, {}});
  265. }
  266. }
  267. TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) {
  268. require_compute_capability(6, 1);
  269. using Param = param::Pooling;
  270. Checker<Pooling> checker(handle_cuda());
  271. Param param;
  272. auto i8_min = std::numeric_limits<int8_t>().min();
  273. auto i8_max = std::numeric_limits<int8_t>().max();
  274. UniformIntRNG int_rng{i8_min, i8_max};
  275. checker.set_dtype(0, dtype::QuantizedS8(0.1f));
  276. param.format = Param::Format::NCHW4;
  277. for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE,
  278. Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) {
  279. param.mode = mode;
  280. checker.set_epsilon(1e-3).set_rng(0, &int_rng);
  281. checker.set_param(param).exec({{64, 8, 28, 28, 4}, {}});
  282. checker.set_param(param).exec({{15, 8, 28, 28, 4}, {}});
  283. checker.set_param(param).exec({{30, 8, 28, 28, 4}, {}});
  284. }
  285. }
  #if MEGDNN_WITH_BENCHMARK
  // Benchmarks QuantizedS8 pooling in NCHW4 layout (reported as "cudnn")
  // against CHWN4 layout for the same problem, printing the per-run time and
  // effective memory bandwidth of each.
  TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) {
      CUBenchmarker<Pooling> bencher(handle_cuda());
      size_t nr_times = 1000;
      bencher.set_times(nr_times);
      using Param = param::Pooling;
      Param param;
      auto run_bench = [&](size_t N, size_t C, size_t H, size_t W, size_t stride,
                           size_t padding, size_t window,
                           Param::Mode mode = Param::Mode::MAX) {
          param.mode = mode;
          param.pad_h = param.pad_w = padding;
          param.window_h = param.window_w = window;
          param.stride_h = param.stride_w = stride;

          param.format = Param::Format::NCHW4;
          bencher.set_dtype(0, dtype::QuantizedS8{0.1f});
          bencher.set_param(param);
          auto time_cudnn = bencher.execs({{N, C / 4, H, W, 4}, {}}) / nr_times;

          param.format = Param::Format::CHWN4;
          bencher.set_param(param);
          auto time_chwn4 = bencher.execs({{C / 4, H, W, N, 4}, {}}) / nr_times;

          size_t oh = infer_conv_shape(H, window, stride, padding);
          size_t ow = infer_conv_shape(W, window, stride, padding);
          // Bytes of int8 input read plus int8 output written.
          float io = (N * C * H * W + N * C * oh * ow) * sizeof(int8_t);
          // io is in bytes and times are reported in ms, so
          // io / (1e6 * ms) is gigabytes per second (GB/s), not gigabits.
          printf("time(cudnn)=%.2f ms, time(chwn4)=%.2f ms, "
                 "bandwidth(cudnn)=%.2f GB/s, bandwidth(chwn4)=%.2f GB/s\n",
                 time_cudnn, time_chwn4, io / (1e6 * time_cudnn),
                 io / (1e6 * time_chwn4));
      };
      run_bench(64, 64, 112, 112, 2, 1, 2);
      run_bench(256, 64, 112, 112, 2, 1, 2);
      run_bench(64, 64, 112, 112, 2, 1, 2, Param::Mode::AVERAGE);
      run_bench(256, 64, 112, 112, 2, 1, 2, Param::Mode::AVERAGE);
      run_bench(64, 64, 112, 112, 2, 1, 2,
                Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING);
      run_bench(256, 64, 112, 112, 2, 1, 2,
                Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING);
  }
  #endif
  325. } // namespace test
  326. } // namespace megdnn
  327. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。