You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. /**
  2. * \file dnn/test/cuda/dct.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/dct_ref.h"
  16. #include "test/common/rng.h"
  17. #include "test/cuda/fixture.h"
  18. namespace megdnn {
  19. namespace test {
  20. TEST_F(CUDA, DCT) {
  21. DctChannelSelectForward::Param param;
  22. Checker<DctChannelSelectForward> checker(handle_cuda());
  23. for (size_t n : {1, 3}) {
  24. for (size_t ic : {1, 3}) {
  25. for (size_t ih : {8, 16, 32, 512, 1024}) {
  26. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  27. checker.set_param(param)
  28. .set_dtype(0, dtype::Uint8())
  29. .set_dtype(1, dtype::Int32())
  30. .set_dtype(2, dtype::Int32())
  31. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  32. }
  33. }
  34. }
  35. }
  36. }
  37. TEST_F(CUDA, DCT_QINT8) {
  38. DctChannelSelectForward::Param param;
  39. Checker<DctChannelSelectForward> checker(handle_cuda());
  40. param.format = Param::Format::NCHW4;
  41. for (size_t n : {1, 3}) {
  42. for (size_t ic : {1, 3}) {
  43. for (size_t ih : {8, 16, 32, 512, 1024}) {
  44. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  45. checker.set_param(param)
  46. .set_dtype(0, dtype::Uint8())
  47. .set_dtype(1, dtype::Int32())
  48. .set_dtype(2, dtype::Int32())
  49. .set_dtype(3, dtype::QuantizedS8(10.f))
  50. .set_epsilon(1)
  51. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  52. }
  53. }
  54. }
  55. }
  56. }
  57. TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {
  58. using Param = DctChannelSelectForward::Param;
  59. Param param;
  60. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  61. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  62. auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
  63. checker.set_param(param).exect(test_case->testcase_in,
  64. test_case->testcase_out);
  65. }
  66. TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {
  67. using Param = DctChannelSelectForward::Param;
  68. Param param;
  69. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  70. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  71. param.format = Param::Format::NCHW4;
  72. auto test_case =
  73. gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
  74. checker.set_param(param).set_epsilon(1).exect(test_case->testcase_in,
  75. test_case->testcase_out);
  76. }
  77. TEST_F(CUDA, DCT_WITH_MASK) {
  78. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  79. DctChannelSelectForward::Param param;
  80. checker.set_param(param).exect(
  81. Testcase{TensorValue(
  82. {1, 3, 8, 16}, dtype::Uint8(),
  83. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
  84. 18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
  85. 185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
  86. 196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
  87. 148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
  88. 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  89. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
  90. 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
  91. 181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
  92. 14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
  93. 65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
  94. 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  95. 238, 181, 232, 191, 161, 57, 23, 204,
  96. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
  97. 18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
  98. 185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
  99. 196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
  100. 148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
  101. 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  102. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
  103. 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
  104. 181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
  105. 14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
  106. 65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
  107. 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  108. 238, 181, 232, 191, 161, 57, 23, 204,
  109. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5,
  110. 18, 16, 93, 185, 99, 102, 205, 172, 191, 29,
  111. 185, 6, 47, 84, 0, 47, 105, 203, 251, 73,
  112. 196, 83, 3, 211, 32, 181, 49, 111, 114, 83,
  113. 148, 232, 77, 17, 35, 2, 154, 100, 41, 135,
  114. 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  115. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240,
  116. 49, 145, 87, 210, 97, 190, 179, 93, 125, 105,
  117. 181, 207, 148, 178, 133, 53, 25, 198, 238, 151,
  118. 14, 120, 213, 195, 145, 20, 122, 107, 217, 185,
  119. 65, 5, 115, 110, 82, 206, 163, 86, 2, 2,
  120. 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  121. 238, 181, 232, 191, 161, 57, 23, 204}),
  122. TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}),
  123. TensorValue({30}, dtype::Int32(),
  124. {8, 16, 9, 2, 3, 10, 17, 24, 32, 25,
  125. 18, 11, 4, 5, 0, 1, 8, 16, 9, 2,
  126. 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  127. {}},
  128. Testcase{{},
  129. {},
  130. {},
  131. TensorValue({1, 30, 1, 2}, dtype::Float32(),
  132. {-22.850792, -97.862236, -101.043236,
  133. -4.727012, 28.275675, -157.96654,
  134. 42.1377, 45.06531, -149.77373,
  135. 24.487143, -8.054966, -13.990831,
  136. -6.9395194, -3.9211385, 64.79172,
  137. -12.363858, -47.875, 59.,
  138. 56.271786, -62.725567, 120.522675,
  139. 16.559765, 85.74334, 112.904495,
  140. 99.375, 29.499973, 2.0220923,
  141. -19.681704, 890.12494, 941.25,
  142. -7.0498576, 99.47632, -22.850792,
  143. -97.862236, -101.043236, -4.727012,
  144. 28.275675, -157.96654, 42.1377,
  145. 45.06531, -149.77373, 24.487143,
  146. -8.054966, -13.990831, 890.12494,
  147. 941.25, -7.0498576, 99.47632,
  148. -22.850792, -97.862236, -101.043236,
  149. -4.727012, 28.275675, -157.96654,
  150. 42.1377, 45.06531, -149.77373,
  151. 24.487143, -8.054966, -13.990831})});
  152. }
  153. TEST_F(CUDA, DCT_WITH_MASK2) {
  154. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  155. DctChannelSelectForward::Param param;
  156. UniformIntRNG rng_oc(0, 3 * 64);
  157. for (size_t n : {1, 3}) {
  158. for (size_t ic : {1, 3}) {
  159. for (size_t ih : {8, 16, 32, 512, 1024}) {
  160. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  161. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  162. int max_oc = ic * 64;
  163. int mask_oc = (random_oc % max_oc) + 1;
  164. auto test_case =
  165. gen_dct_case(n, ic, ih, iw, mask_oc, param);
  166. checker.set_param(param).exect(test_case->testcase_in,
  167. test_case->testcase_out);
  168. }
  169. }
  170. }
  171. }
  172. }
  173. TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {
  174. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  175. DctChannelSelectForward::Param param;
  176. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  177. UniformIntRNG rng_oc(0, 3 * 64);
  178. for (size_t n : {1, 3}) {
  179. for (size_t ic : {1, 3}) {
  180. for (size_t ih : {8, 16, 32, 512, 1024}) {
  181. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  182. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  183. int max_oc = ic * 64;
  184. int mask_oc = (random_oc % max_oc) + 1;
  185. mask_oc = (mask_oc + 3) / 4 * 4;
  186. auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param,
  187. dtype::QuantizedS8(10.f));
  188. checker.set_param(param).set_epsilon(1).exect(
  189. test_case->testcase_in, test_case->testcase_out);
  190. }
  191. }
  192. }
  193. }
  194. }
  195. TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {
  196. DctChannelSelectForward::Param param;
  197. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  198. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  199. checker.set_param(param)
  200. .set_dtype(0, dtype::Uint8())
  201. .set_dtype(1, dtype::Int32())
  202. .set_dtype(2, dtype::Int32())
  203. .set_dtype(3, dtype::QuantizedS8(10.f))
  204. .set_epsilon(1);
  205. UniformIntRNG rng_oc(0, 3 * 64);
  206. for (size_t n : {1, 3}) {
  207. for (size_t ic : {1, 3}) {
  208. for (size_t ih : {8, 16, 32, 512, 1024}) {
  209. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  210. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  211. int max_oc = ic * 64;
  212. int mask_oc = (random_oc % max_oc) + 1;
  213. mask_oc = (mask_oc + 3) / 4 * 4;
  214. if (mask_oc < max_oc) {
  215. checker
  216. .set_tensors_constraint(gen_dct_constriant(
  217. n, ic, ih, iw, mask_oc, param))
  218. .exec({TensorShape{n, ic, ih, iw},
  219. TensorShape{ic + 1},
  220. TensorShape{(size_t)mask_oc},
  221. {}});
  222. } else {
  223. checker.set_tensors_constraint({}).exec(
  224. {TensorShape{n, ic, ih, iw}, {}, {}, {}});
  225. }
  226. }
  227. }
  228. }
  229. }
  230. }
  231. #if MEGDNN_WITH_BENCHMARK
  232. TEST_F(CUDA, BENCHMARK_DCT) {
  233. using Param = DctChannelSelectForward::Param;
  234. auto run = [&](const TensorShapeArray& shapes, Param param) {
  235. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  236. benchmarker.set_param(param);
  237. benchmarker.set_dtype(0, dtype::Uint8())
  238. .set_dtype(1, dtype::Int32())
  239. .set_dtype(2, dtype::Int32());
  240. for (auto&& shape : shapes) {
  241. double computation = double(shape[0]) * shape[1] * shape[2] *
  242. shape[3] * 32.0 * 1e-6;
  243. auto time_ms = benchmarker.execs({shape, {}, {}, {}});
  244. printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
  245. computation / time_ms);
  246. }
  247. };
  248. auto run_case = [&](const DctTestcase& testcase, Param param,
  249. std::string comment = "") {
  250. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  251. benchmarker.set_param(param);
  252. benchmarker.set_dtype(0, dtype::Uint8())
  253. .set_dtype(1, dtype::Int32())
  254. .set_dtype(2, dtype::Int32())
  255. .set_dtype(3, testcase.testcase_out[3].layout.dtype);
  256. auto src_shape = testcase.testcase_in[0].layout;
  257. double computation = double(src_shape[0]) * src_shape[1] *
  258. src_shape[2] * src_shape[3] * 32.0 * 1e-6;
  259. auto time_ms = benchmarker.exect(testcase.testcase_in);
  260. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  261. src_shape.to_string().c_str(), computation / time_ms);
  262. };
  263. auto run_case_constraint =
  264. [&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
  265. constraint,
  266. Param param, const TensorShapeArray& shapes,
  267. std::string comment = "", DType output_dtype) {
  268. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  269. benchmarker.set_param(param)
  270. .set_dtype(0, dtype::Uint8())
  271. .set_dtype(1, dtype::Int32())
  272. .set_dtype(2, dtype::Int32())
  273. .set_dtype(3, output_dtype)
  274. .set_tensors_constraint(constraint);
  275. auto src_shape = shapes[0];
  276. double computation = double(src_shape[0]) * src_shape[1] *
  277. src_shape[2] * src_shape[3] * 32.0 * 1e-6;
  278. auto time_ms = benchmarker.exec(shapes);
  279. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  280. src_shape.to_string().c_str(), computation / time_ms);
  281. };
  282. TensorShapeArray shapes = {
  283. {1, 3, 512, 512},
  284. {8, 3, 2176, 3840},
  285. };
  286. {
  287. Param param;
  288. run(shapes, param);
  289. }
  290. Param fix_32_param;
  291. fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  292. {
  293. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  294. run_case(*test_case, fix_32_param, "FIX_32_MASK");
  295. }
  296. {
  297. Param param;
  298. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  299. run_case(*test_case, param, "MASK 32");
  300. }
  301. {
  302. Param fix_32_nchw4_param;
  303. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  304. fix_32_nchw4_param.format = Param::Format::NCHW4;
  305. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
  306. dtype::QuantizedS8(10.f));
  307. run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
  308. }
  309. {
  310. Param fix_32_nchw4_param;
  311. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  312. fix_32_nchw4_param.format = Param::Format::NCHW4;
  313. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_nchw4_param,
  314. dtype::QuantizedS8(10.f));
  315. fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
  316. run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
  317. }
  318. {
  319. Param fix_32_nchw4_param;
  320. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  321. fix_32_nchw4_param.format = Param::Format::NCHW4;
  322. TensorShapeArray shapes = {{8, 3, 2176, 3840}, {4}, {32}, {}};
  323. auto constraint =
  324. gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
  325. run_case_constraint(constraint, fix_32_nchw4_param, shapes,
  326. "FIX_32_MASK QINT8 Constraint",
  327. dtype::QuantizedS8(10.f));
  328. }
  329. }
  330. #endif
  331. } // namespace test
  332. } // namespace megdnn
  333. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。