You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

dct.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. /**
  2. * \file dnn/test/cuda/dct.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/dct_ref.h"
  16. #include "test/common/rng.h"
  17. #include "test/cuda/fixture.h"
  18. namespace megdnn {
  19. namespace test {
  20. TEST_F(CUDA, DCT) {
  21. DctChannelSelectForward::Param param;
  22. Checker<DctChannelSelectForward> checker(handle_cuda());
  23. for (size_t n : {1, 3}) {
  24. for (size_t ic : {1, 3}) {
  25. for (size_t ih : {8, 16, 32, 512, 1024}) {
  26. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  27. checker.set_param(param)
  28. .set_dtype(0, dtype::Uint8())
  29. .set_dtype(1, dtype::Int32())
  30. .set_dtype(2, dtype::Int32())
  31. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  32. }
  33. }
  34. }
  35. }
  36. }
  37. TEST_F(CUDA, DCT_QINT8) {
  38. DctChannelSelectForward::Param param;
  39. Checker<DctChannelSelectForward> checker(handle_cuda());
  40. param.format = Param::Format::NCHW4;
  41. for (size_t n : {1, 3}) {
  42. for (size_t ic : {1, 3}) {
  43. for (size_t ih : {8, 16, 32, 512, 1024}) {
  44. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  45. checker.set_param(param)
  46. .set_dtype(0, dtype::Uint8())
  47. .set_dtype(1, dtype::Int32())
  48. .set_dtype(2, dtype::Int32())
  49. .set_dtype(3, dtype::QuantizedS8(10.f))
  50. .set_epsilon(1)
  51. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  52. }
  53. }
  54. }
  55. }
  56. }
  57. TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {
  58. using Param = DctChannelSelectForward::Param;
  59. Param param;
  60. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  61. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  62. auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
  63. checker.set_param(param).exect(test_case->testcase_in, test_case->testcase_out);
  64. }
  65. TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {
  66. using Param = DctChannelSelectForward::Param;
  67. Param param;
  68. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  69. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  70. param.format = Param::Format::NCHW4;
  71. auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
  72. checker.set_param(param).set_epsilon(1).exect(
  73. test_case->testcase_in, test_case->testcase_out);
  74. }
  75. TEST_F(CUDA, DCT_WITH_MASK) {
  76. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  77. DctChannelSelectForward::Param param;
  78. checker.set_param(param).exect(
  79. Testcase{
  80. TensorValue(
  81. {1, 3, 8, 16}, dtype::Uint8(),
  82. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  83. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  84. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  85. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  86. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  87. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  88. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  89. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  90. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  91. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  92. 238, 181, 232, 191, 161, 57, 23, 204,
  93. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  94. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  95. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  96. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  97. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  98. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  99. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  100. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  101. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  102. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  103. 238, 181, 232, 191, 161, 57, 23, 204,
  104. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  105. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  106. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  107. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  108. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  109. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  110. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  111. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  112. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  113. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  114. 238, 181, 232, 191, 161, 57, 23, 204}),
  115. TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}),
  116. TensorValue(
  117. {30}, dtype::Int32(),
  118. {8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 0,
  119. 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  120. {}},
  121. Testcase{
  122. {},
  123. {},
  124. {},
  125. TensorValue(
  126. {1, 30, 1, 2}, dtype::Float32(),
  127. {-22.850792, -97.862236, -101.043236, -4.727012,
  128. 28.275675, -157.96654, 42.1377, 45.06531,
  129. -149.77373, 24.487143, -8.054966, -13.990831,
  130. -6.9395194, -3.9211385, 64.79172, -12.363858,
  131. -47.875, 59., 56.271786, -62.725567,
  132. 120.522675, 16.559765, 85.74334, 112.904495,
  133. 99.375, 29.499973, 2.0220923, -19.681704,
  134. 890.12494, 941.25, -7.0498576, 99.47632,
  135. -22.850792, -97.862236, -101.043236, -4.727012,
  136. 28.275675, -157.96654, 42.1377, 45.06531,
  137. -149.77373, 24.487143, -8.054966, -13.990831,
  138. 890.12494, 941.25, -7.0498576, 99.47632,
  139. -22.850792, -97.862236, -101.043236, -4.727012,
  140. 28.275675, -157.96654, 42.1377, 45.06531,
  141. -149.77373, 24.487143, -8.054966, -13.990831})});
  142. }
  143. TEST_F(CUDA, DCT_WITH_MASK2) {
  144. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  145. DctChannelSelectForward::Param param;
  146. UniformIntRNG rng_oc(0, 3 * 64);
  147. for (size_t n : {1, 3}) {
  148. for (size_t ic : {1, 3}) {
  149. for (size_t ih : {8, 16, 32, 512, 1024}) {
  150. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  151. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  152. int max_oc = ic * 64;
  153. int mask_oc = (random_oc % max_oc) + 1;
  154. auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param);
  155. checker.set_param(param).exect(
  156. test_case->testcase_in, test_case->testcase_out);
  157. }
  158. }
  159. }
  160. }
  161. }
  162. TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {
  163. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  164. DctChannelSelectForward::Param param;
  165. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  166. UniformIntRNG rng_oc(0, 3 * 64);
  167. for (size_t n : {1, 3}) {
  168. for (size_t ic : {1, 3}) {
  169. for (size_t ih : {8, 16, 32, 512, 1024}) {
  170. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  171. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  172. int max_oc = ic * 64;
  173. int mask_oc = (random_oc % max_oc) + 1;
  174. mask_oc = (mask_oc + 3) / 4 * 4;
  175. auto test_case = gen_dct_case(
  176. n, ic, ih, iw, mask_oc, param, dtype::QuantizedS8(10.f));
  177. checker.set_param(param).set_epsilon(1).exect(
  178. test_case->testcase_in, test_case->testcase_out);
  179. }
  180. }
  181. }
  182. }
  183. }
  184. TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {
  185. DctChannelSelectForward::Param param;
  186. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  187. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  188. checker.set_param(param)
  189. .set_dtype(0, dtype::Uint8())
  190. .set_dtype(1, dtype::Int32())
  191. .set_dtype(2, dtype::Int32())
  192. .set_dtype(3, dtype::QuantizedS8(10.f))
  193. .set_epsilon(1);
  194. UniformIntRNG rng_oc(0, 3 * 64);
  195. for (size_t n : {1, 3}) {
  196. for (size_t ic : {1, 3}) {
  197. for (size_t ih : {8, 16, 32, 512, 1024}) {
  198. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  199. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  200. int max_oc = ic * 64;
  201. int mask_oc = (random_oc % max_oc) + 1;
  202. mask_oc = (mask_oc + 3) / 4 * 4;
  203. if (mask_oc < max_oc) {
  204. checker
  205. .set_tensors_constraint(gen_dct_constriant(
  206. n, ic, ih, iw, mask_oc, param))
  207. .exec({TensorShape{n, ic, ih, iw},
  208. TensorShape{ic + 1},
  209. TensorShape{(size_t)mask_oc},
  210. {}});
  211. } else {
  212. checker.set_tensors_constraint({}).exec(
  213. {TensorShape{n, ic, ih, iw}, {}, {}, {}});
  214. }
  215. }
  216. }
  217. }
  218. }
  219. }
  220. #if MEGDNN_WITH_BENCHMARK
  221. TEST_F(CUDA, BENCHMARK_DCT) {
  222. using Param = DctChannelSelectForward::Param;
  223. auto run = [&](const TensorShapeArray& shapes, Param param) {
  224. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  225. benchmarker.set_param(param);
  226. benchmarker.set_dtype(0, dtype::Uint8())
  227. .set_dtype(1, dtype::Int32())
  228. .set_dtype(2, dtype::Int32());
  229. for (auto&& shape : shapes) {
  230. double computation =
  231. double(shape[0]) * shape[1] * shape[2] * shape[3] * 32.0 * 1e-6;
  232. auto time_ms = benchmarker.execs({shape, {}, {}, {}});
  233. printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
  234. computation / time_ms);
  235. }
  236. };
  237. auto run_case = [&](const DctTestcase& testcase, Param param,
  238. std::string comment = "") {
  239. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  240. benchmarker.set_param(param);
  241. benchmarker.set_dtype(0, dtype::Uint8())
  242. .set_dtype(1, dtype::Int32())
  243. .set_dtype(2, dtype::Int32())
  244. .set_dtype(3, testcase.testcase_out[3].layout.dtype);
  245. auto src_shape = testcase.testcase_in[0].layout;
  246. double computation = double(src_shape[0]) * src_shape[1] * src_shape[2] *
  247. src_shape[3] * 32.0 * 1e-6;
  248. auto time_ms = benchmarker.exect(testcase.testcase_in);
  249. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  250. src_shape.to_string().c_str(), computation / time_ms);
  251. };
  252. auto run_case_constraint =
  253. [&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
  254. constraint,
  255. Param param, const TensorShapeArray& shapes, std::string comment = "",
  256. DType output_dtype) {
  257. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  258. benchmarker.set_param(param)
  259. .set_dtype(0, dtype::Uint8())
  260. .set_dtype(1, dtype::Int32())
  261. .set_dtype(2, dtype::Int32())
  262. .set_dtype(3, output_dtype)
  263. .set_tensors_constraint(constraint);
  264. auto src_shape = shapes[0];
  265. double computation = double(src_shape[0]) * src_shape[1] *
  266. src_shape[2] * src_shape[3] * 32.0 * 1e-6;
  267. auto time_ms = benchmarker.exec(shapes);
  268. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  269. src_shape.to_string().c_str(), computation / time_ms);
  270. };
  271. TensorShapeArray shapes = {
  272. {1, 3, 512, 512},
  273. {8, 3, 2176, 3840},
  274. };
  275. {
  276. Param param;
  277. run(shapes, param);
  278. }
  279. Param fix_32_param;
  280. fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  281. {
  282. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  283. run_case(*test_case, fix_32_param, "FIX_32_MASK");
  284. }
  285. {
  286. Param param;
  287. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  288. run_case(*test_case, param, "MASK 32");
  289. }
  290. {
  291. Param fix_32_nchw4_param;
  292. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  293. fix_32_nchw4_param.format = Param::Format::NCHW4;
  294. auto test_case = gen_dct_case(
  295. 8, 3, 2176, 3840, 32, fix_32_nchw4_param, dtype::QuantizedS8(10.f));
  296. run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
  297. }
  298. {
  299. Param fix_32_nchw4_param;
  300. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  301. fix_32_nchw4_param.format = Param::Format::NCHW4;
  302. auto test_case = gen_dct_case(
  303. 8, 3, 2176, 3840, 32, fix_32_nchw4_param, dtype::QuantizedS8(10.f));
  304. fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
  305. run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
  306. }
  307. {
  308. Param fix_32_nchw4_param;
  309. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  310. fix_32_nchw4_param.format = Param::Format::NCHW4;
  311. TensorShapeArray shapes = {{8, 3, 2176, 3840}, {4}, {32}, {}};
  312. auto constraint = gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
  313. run_case_constraint(
  314. constraint, fix_32_nchw4_param, shapes, "FIX_32_MASK QINT8 Constraint",
  315. dtype::QuantizedS8(10.f));
  316. }
  317. }
  318. #endif
  319. } // namespace test
  320. } // namespace megdnn
  321. // vim: syntax=cpp.doxygen