You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dct.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. #include "megdnn/oprs/nn.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/dct_ref.h"
  5. #include "test/common/rng.h"
  6. #include "test/cuda/fixture.h"
  7. namespace megdnn {
  8. namespace test {
  9. TEST_F(CUDA, DCT) {
  10. DctChannelSelectForward::Param param;
  11. Checker<DctChannelSelectForward> checker(handle_cuda());
  12. for (size_t n : {1, 3}) {
  13. for (size_t ic : {1, 3}) {
  14. for (size_t ih : {8, 16, 32, 512, 1024}) {
  15. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  16. checker.set_param(param)
  17. .set_dtype(0, dtype::Uint8())
  18. .set_dtype(1, dtype::Int32())
  19. .set_dtype(2, dtype::Int32())
  20. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  21. }
  22. }
  23. }
  24. }
  25. }
  26. TEST_F(CUDA, DCT_QINT8) {
  27. DctChannelSelectForward::Param param;
  28. Checker<DctChannelSelectForward> checker(handle_cuda());
  29. param.format = Param::Format::NCHW4;
  30. for (size_t n : {1, 3}) {
  31. for (size_t ic : {1, 3}) {
  32. for (size_t ih : {8, 16, 32, 512, 1024}) {
  33. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  34. checker.set_param(param)
  35. .set_dtype(0, dtype::Uint8())
  36. .set_dtype(1, dtype::Int32())
  37. .set_dtype(2, dtype::Int32())
  38. .set_dtype(3, dtype::QuantizedS8(10.f))
  39. .set_epsilon(1)
  40. .execs({TensorShape{n, ic, ih, iw}, {}, {}, {}});
  41. }
  42. }
  43. }
  44. }
  45. }
  46. TEST_F(CUDA, DCT_WITH_FIX_32_MASK) {
  47. using Param = DctChannelSelectForward::Param;
  48. Param param;
  49. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  50. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  51. auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param);
  52. checker.set_param(param).exect(test_case->testcase_in, test_case->testcase_out);
  53. }
  54. TEST_F(CUDA, DCT_WITH_FIX_32_MASK_QINT8) {
  55. using Param = DctChannelSelectForward::Param;
  56. Param param;
  57. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  58. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  59. param.format = Param::Format::NCHW4;
  60. auto test_case = gen_dct_case(3, 3, 1024, 768, 32, param, dtype::QuantizedS8(10.f));
  61. checker.set_param(param).set_epsilon(1).exect(
  62. test_case->testcase_in, test_case->testcase_out);
  63. }
  64. TEST_F(CUDA, DCT_WITH_MASK) {
  65. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  66. DctChannelSelectForward::Param param;
  67. checker.set_param(param).exect(
  68. Testcase{
  69. TensorValue(
  70. {1, 3, 8, 16}, dtype::Uint8(),
  71. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  72. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  73. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  74. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  75. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  76. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  77. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  78. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  79. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  80. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  81. 238, 181, 232, 191, 161, 57, 23, 204,
  82. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  83. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  84. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  85. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  86. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  87. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  88. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  89. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  90. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  91. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  92. 238, 181, 232, 191, 161, 57, 23, 204,
  93. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  94. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  95. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  96. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  97. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  98. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  99. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  100. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  101. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  102. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  103. 238, 181, 232, 191, 161, 57, 23, 204}),
  104. TensorValue({4}, dtype::Int32(), {0, 14, 22, 30}),
  105. TensorValue(
  106. {30}, dtype::Int32(),
  107. {8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 0,
  108. 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  109. {}},
  110. Testcase{
  111. {},
  112. {},
  113. {},
  114. TensorValue(
  115. {1, 30, 1, 2}, dtype::Float32(),
  116. {-22.850792, -97.862236, -101.043236, -4.727012,
  117. 28.275675, -157.96654, 42.1377, 45.06531,
  118. -149.77373, 24.487143, -8.054966, -13.990831,
  119. -6.9395194, -3.9211385, 64.79172, -12.363858,
  120. -47.875, 59., 56.271786, -62.725567,
  121. 120.522675, 16.559765, 85.74334, 112.904495,
  122. 99.375, 29.499973, 2.0220923, -19.681704,
  123. 890.12494, 941.25, -7.0498576, 99.47632,
  124. -22.850792, -97.862236, -101.043236, -4.727012,
  125. 28.275675, -157.96654, 42.1377, 45.06531,
  126. -149.77373, 24.487143, -8.054966, -13.990831,
  127. 890.12494, 941.25, -7.0498576, 99.47632,
  128. -22.850792, -97.862236, -101.043236, -4.727012,
  129. 28.275675, -157.96654, 42.1377, 45.06531,
  130. -149.77373, 24.487143, -8.054966, -13.990831})});
  131. }
  132. TEST_F(CUDA, DCT_WITH_MASK2) {
  133. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  134. DctChannelSelectForward::Param param;
  135. UniformIntRNG rng_oc(0, 3 * 64);
  136. for (size_t n : {1, 3}) {
  137. for (size_t ic : {1, 3}) {
  138. for (size_t ih : {8, 16, 32, 512, 1024}) {
  139. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  140. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  141. int max_oc = ic * 64;
  142. int mask_oc = (random_oc % max_oc) + 1;
  143. auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param);
  144. checker.set_param(param).exect(
  145. test_case->testcase_in, test_case->testcase_out);
  146. }
  147. }
  148. }
  149. }
  150. }
  151. TEST_F(CUDA, DCT_WITH_MASK2_QINT8) {
  152. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  153. DctChannelSelectForward::Param param;
  154. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  155. UniformIntRNG rng_oc(0, 3 * 64);
  156. for (size_t n : {1, 3}) {
  157. for (size_t ic : {1, 3}) {
  158. for (size_t ih : {8, 16, 32, 512, 1024}) {
  159. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  160. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  161. int max_oc = ic * 64;
  162. int mask_oc = (random_oc % max_oc) + 1;
  163. mask_oc = (mask_oc + 3) / 4 * 4;
  164. auto test_case = gen_dct_case(
  165. n, ic, ih, iw, mask_oc, param, dtype::QuantizedS8(10.f));
  166. checker.set_param(param).set_epsilon(1).exect(
  167. test_case->testcase_in, test_case->testcase_out);
  168. }
  169. }
  170. }
  171. }
  172. }
  173. TEST_F(CUDA, DCT_WITH_MASK2_QINT8_CONSTRAINT) {
  174. DctChannelSelectForward::Param param;
  175. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  176. Checker<DctChannelSelectForward> checker(handle_cuda(), false);
  177. checker.set_param(param)
  178. .set_dtype(0, dtype::Uint8())
  179. .set_dtype(1, dtype::Int32())
  180. .set_dtype(2, dtype::Int32())
  181. .set_dtype(3, dtype::QuantizedS8(10.f))
  182. .set_epsilon(1);
  183. UniformIntRNG rng_oc(0, 3 * 64);
  184. for (size_t n : {1, 3}) {
  185. for (size_t ic : {1, 3}) {
  186. for (size_t ih : {8, 16, 32, 512, 1024}) {
  187. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  188. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  189. int max_oc = ic * 64;
  190. int mask_oc = (random_oc % max_oc) + 1;
  191. mask_oc = (mask_oc + 3) / 4 * 4;
  192. if (mask_oc < max_oc) {
  193. checker
  194. .set_tensors_constraint(gen_dct_constriant(
  195. n, ic, ih, iw, mask_oc, param))
  196. .exec({TensorShape{n, ic, ih, iw},
  197. TensorShape{ic + 1},
  198. TensorShape{(size_t)mask_oc},
  199. {}});
  200. } else {
  201. checker.set_tensors_constraint({}).exec(
  202. {TensorShape{n, ic, ih, iw}, {}, {}, {}});
  203. }
  204. }
  205. }
  206. }
  207. }
  208. }
  209. #if MEGDNN_WITH_BENCHMARK
  210. TEST_F(CUDA, BENCHMARK_DCT) {
  211. using Param = DctChannelSelectForward::Param;
  212. auto run = [&](const TensorShapeArray& shapes, Param param) {
  213. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  214. benchmarker.set_param(param);
  215. benchmarker.set_dtype(0, dtype::Uint8())
  216. .set_dtype(1, dtype::Int32())
  217. .set_dtype(2, dtype::Int32());
  218. for (auto&& shape : shapes) {
  219. double computation =
  220. double(shape[0]) * shape[1] * shape[2] * shape[3] * 32.0 * 1e-6;
  221. auto time_ms = benchmarker.execs({shape, {}, {}, {}});
  222. printf("execute %s, %.4f Gops\n", shape.to_string().c_str(),
  223. computation / time_ms);
  224. }
  225. };
  226. auto run_case = [&](const DctTestcase& testcase, Param param,
  227. std::string comment = "") {
  228. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  229. benchmarker.set_param(param);
  230. benchmarker.set_dtype(0, dtype::Uint8())
  231. .set_dtype(1, dtype::Int32())
  232. .set_dtype(2, dtype::Int32())
  233. .set_dtype(3, testcase.testcase_out[3].layout.dtype);
  234. auto src_shape = testcase.testcase_in[0].layout;
  235. double computation = double(src_shape[0]) * src_shape[1] * src_shape[2] *
  236. src_shape[3] * 32.0 * 1e-6;
  237. auto time_ms = benchmarker.exect(testcase.testcase_in);
  238. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  239. src_shape.to_string().c_str(), computation / time_ms);
  240. };
  241. auto run_case_constraint =
  242. [&](const Benchmarker<DctChannelSelectForward>::TensorsConstriant&
  243. constraint,
  244. Param param, const TensorShapeArray& shapes, std::string comment = "",
  245. DType output_dtype) {
  246. Benchmarker<DctChannelSelectForward> benchmarker(handle_cuda());
  247. benchmarker.set_param(param)
  248. .set_dtype(0, dtype::Uint8())
  249. .set_dtype(1, dtype::Int32())
  250. .set_dtype(2, dtype::Int32())
  251. .set_dtype(3, output_dtype)
  252. .set_tensors_constraint(constraint);
  253. auto src_shape = shapes[0];
  254. double computation = double(src_shape[0]) * src_shape[1] *
  255. src_shape[2] * src_shape[3] * 32.0 * 1e-6;
  256. auto time_ms = benchmarker.exec(shapes);
  257. printf("[%s] execute %s, %.4f Gops\n", comment.c_str(),
  258. src_shape.to_string().c_str(), computation / time_ms);
  259. };
  260. TensorShapeArray shapes = {
  261. {1, 3, 512, 512},
  262. {8, 3, 2176, 3840},
  263. };
  264. {
  265. Param param;
  266. run(shapes, param);
  267. }
  268. Param fix_32_param;
  269. fix_32_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  270. {
  271. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  272. run_case(*test_case, fix_32_param, "FIX_32_MASK");
  273. }
  274. {
  275. Param param;
  276. auto test_case = gen_dct_case(8, 3, 2176, 3840, 32, fix_32_param);
  277. run_case(*test_case, param, "MASK 32");
  278. }
  279. {
  280. Param fix_32_nchw4_param;
  281. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  282. fix_32_nchw4_param.format = Param::Format::NCHW4;
  283. auto test_case = gen_dct_case(
  284. 8, 3, 2176, 3840, 32, fix_32_nchw4_param, dtype::QuantizedS8(10.f));
  285. run_case(*test_case, fix_32_nchw4_param, "FIX_32_MASK QINT8");
  286. }
  287. {
  288. Param fix_32_nchw4_param;
  289. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  290. fix_32_nchw4_param.format = Param::Format::NCHW4;
  291. auto test_case = gen_dct_case(
  292. 8, 3, 2176, 3840, 32, fix_32_nchw4_param, dtype::QuantizedS8(10.f));
  293. fix_32_nchw4_param.fastImpl = Param::FastImpl::NONE;
  294. run_case(*test_case, fix_32_nchw4_param, "MASK 32 QINT8");
  295. }
  296. {
  297. Param fix_32_nchw4_param;
  298. fix_32_nchw4_param.fastImpl = Param::FastImpl::FIX_32_MASK;
  299. fix_32_nchw4_param.format = Param::Format::NCHW4;
  300. TensorShapeArray shapes = {{8, 3, 2176, 3840}, {4}, {32}, {}};
  301. auto constraint = gen_dct_constriant(8, 3, 2176, 3840, 32, fix_32_nchw4_param);
  302. run_case_constraint(
  303. constraint, fix_32_nchw4_param, shapes, "FIX_32_MASK QINT8 Constraint",
  304. dtype::QuantizedS8(10.f));
  305. }
  306. }
  307. #endif
  308. } // namespace test
  309. } // namespace megdnn
  310. // vim: syntax=cpp.doxygen