You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. /**
  2. * \file dnn/test/cuda/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/cuda/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/matrix_mul.h"
  16. #if defined(cuda_check)
  17. #undef cuda_check
  18. #endif
  19. #include "test/cuda/utils.h"
  20. #include <cuda.h>
  21. namespace megdnn {
  22. namespace test {
  23. #if CUDA_VERSION >= 10000
  24. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION) {
  25. if (cuda::current_device_prop().major > 7 ||
  26. (cuda::current_device_prop().major == 7 &&
  27. cuda::current_device_prop().minor >= 5)) {
  28. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION test as current "
  29. "device support wmma intrinsics\n");
  30. return;
  31. }
  32. Checker<MatrixMul> checker(handle_cuda(), false);
  33. using Param = MatrixMul::Param;
  34. Param param;
  35. param.transposeB = true;
  36. checker.set_param(param);
  37. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  38. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  39. checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
  40. ASSERT_THROW(checker.exec({{256, 256}, {256, 256}, {256, 256}}), MegDNNError);
  41. }
  42. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32) {
  43. if (cuda::current_device_prop().major < 7 ||
  44. (cuda::current_device_prop().major == 7 &&
  45. cuda::current_device_prop().minor < 5)) {
  46. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32 test as current device "
  47. "doesn't support\n");
  48. return;
  49. }
  50. Checker<MatrixMul> checker(handle_cuda(), false);
  51. using Param = MatrixMul::Param;
  52. Param param;
  53. param.transposeB = true;
  54. checker.set_param(param);
  55. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  56. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  57. checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
  58. checker.exec({{256, 256}, {256, 256}, {256, 256}});
  59. auto args = matrix_mul::get_matmul_args();
  60. for (auto arg : args) {
  61. size_t m = DIVUP(arg.m, 8) * 8, n = DIVUP(arg.n, 8) * 8,
  62. k = DIVUP(arg.k, 32) * 32;
  63. checker.exec({{m, k}, {n, k}, {m, n}});
  64. }
  65. }
  66. #if MEGDNN_WITH_BENCHMARK
  67. TEST_F(CUDA, BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  68. if (cuda::current_device_prop().major < 7 ||
  69. (cuda::current_device_prop().major == 7 &&
  70. cuda::current_device_prop().minor < 5)) {
  71. printf("Skip CUDA.BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as current "
  72. "device doesn't support\n");
  73. return;
  74. }
  75. Benchmarker<MatrixMul> bencher(handle_cuda());
  76. using Param = MatrixMul::Param;
  77. Param param;
  78. param.transposeB = true;
  79. bencher.set_param(param);
  80. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  81. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  82. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  83. for (size_t m : {256, 1024, 4096, 10240, 40960}) {
  84. for (size_t n : {256, 1024, 4096}) {
  85. for (size_t k : {512, 1024, 2048}) {
  86. bencher.set_times(400);
  87. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  88. auto gflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  89. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m, k, n,
  90. time_in_ms, gflps);
  91. }
  92. }
  93. }
  94. }
  95. TEST_F(CUDA, PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  96. if (cuda::current_device_prop().major < 7 ||
  97. (cuda::current_device_prop().major == 7 &&
  98. cuda::current_device_prop().minor < 5)) {
  99. printf("Skip CUDA.PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as "
  100. "current "
  101. "device doesn't support\n");
  102. return;
  103. }
  104. Benchmarker<MatrixMul> bencher(handle_cuda());
  105. using Param = MatrixMul::Param;
  106. Param param;
  107. param.transposeB = true;
  108. bencher.set_param(param);
  109. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  110. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  111. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  112. bencher.set_times(400);
  113. size_t m = 4096, n = 4096, k = 81920;
  114. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  115. auto tflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  116. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m, k, n, time_in_ms,
  117. tflps);
  118. }
  119. #endif
  120. #endif
  121. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_WITH_SPETIAL_STRIDES) {
  122. require_compute_capability(6, 1);
  123. Checker<MatrixMul> checker(handle_cuda());
  124. using Param = MatrixMul::Param;
  125. Param param;
  126. DType stype = dtype::Int8();
  127. checker.set_param(param)
  128. .set_dtype(0, stype)
  129. .set_dtype(1, stype)
  130. .set_dtype(2, dtype::Int32())
  131. .set_epsilon(5e-3);
  132. size_t m = 1024, n = 1024, k = 1024;
  133. {
  134. TensorLayout A{{m, k}, {2048, 1}, dtype::Int8()},
  135. B{{k, n}, {2048, 1}, dtype::Int8()}, C{{m, n}, dtype::Int32()};
  136. checker.execl({A, B, {}});
  137. }
  138. }
  139. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_NAIVE) {
  140. require_compute_capability(6, 1);
  141. using Param = MatrixMul::Param;
  142. UniformIntRNG rng{-128, 127};
  143. Checker<MatrixMul> checker(handle_cuda());
  144. checker.set_rng(0, &rng).set_rng(1, &rng);
  145. size_t m = 1007, n = 1003, k = 129;
  146. for (unsigned mask = 0; mask < 4; ++mask) {
  147. Param param;
  148. param.transposeA = mask & 1;
  149. param.transposeB = mask & 2;
  150. TensorShape A, B;
  151. if (param.transposeA)
  152. A = TensorShape{k, m};
  153. else
  154. A = TensorShape{m, k};
  155. if (param.transposeB)
  156. B = TensorShape{n, k};
  157. else
  158. B = TensorShape{k, n};
  159. checker.set_param(param)
  160. .set_dtype(0, dtype::Int8())
  161. .set_dtype(1, dtype::Int8())
  162. .set_dtype(2, dtype::Int32())
  163. .set_epsilon(0)
  164. .execs({A, B, {}});
  165. }
  166. }
  167. TEST_F(CUDA, MATRIX_MUL_FLOAT_NAIVE) {
  168. Checker<MatrixMul> checker(handle_cuda());
  169. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("NAIVE"));
  170. using Param = MatrixMul::Param;
  171. size_t m = 12, n = 16, k = 20;
  172. std::vector<DType> dtype_array;
  173. dtype_array.push_back(dtype::Float32());
  174. dtype_array.push_back(dtype::Float16());
  175. for (DType dtype : dtype_array) {
  176. for (unsigned mask = 0; mask < 4; ++mask) {
  177. Param param;
  178. param.transposeA = mask & 1;
  179. param.transposeB = mask & 2;
  180. DType stype = dtype;
  181. TensorShape A, B;
  182. if (param.transposeA)
  183. A = TensorShape{k, m};
  184. else
  185. A = TensorShape{m, k};
  186. if (param.transposeB)
  187. B = TensorShape{n, k};
  188. else
  189. B = TensorShape{k, n};
  190. if (dtype == dtype::Float16()) {
  191. param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
  192. }
  193. checker.set_param(param)
  194. .set_dtype(0, stype)
  195. .set_dtype(1, stype)
  196. .set_dtype(2, dtype)
  197. .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
  198. .execs({A, B, {}});
  199. }
  200. }
  201. }
TEST_F(CUDA, MATRIX_MUL) {
    // Broad matmul coverage: all four transpose combinations over
    // f32/f16/bf16 (plus int8 -> int32 when compute capability >= 6.1),
    // followed by the shared stride-aware shape list.
    Checker<MatrixMul> checker(handle_cuda());
    using Param = MatrixMul::Param;
    size_t m = 12, n = 16, k = 20;
    // int8 matmul is only added on sm_61+ devices
    bool is_int_available = check_compute_capability(6, 1);
    std::vector<DType> dtype_array;
    dtype_array.push_back(dtype::Float32());
    dtype_array.push_back(dtype::Float16());
    dtype_array.push_back(dtype::BFloat16());
    if (is_int_available)
        dtype_array.push_back(dtype::Int32());
    for (DType dtype : dtype_array) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            // Int32 output denotes int8 x int8 -> int32; every other dtype
            // uses the same type for inputs and output.
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            if (dtype == dtype::BFloat16()) {
                // bf16 is pinned to the MATMUL_BFLOAT16 algorithm (backed by
                // CUBLAS) with f32 compute mode
                param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
                checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(
                        ExecutionPolicyAlgoName{"MATMUL_BFLOAT16", {{"CUBLAS", {}}}}));
            }
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    // reduced-precision dtypes get a looser tolerance
                    .set_epsilon(
                            dtype == dtype::Float16() || dtype == dtype::BFloat16()
                            ? 5e-2
                            : 5e-3)
                    .execs({A, B, {}});
            if (dtype == dtype::BFloat16()) {
                // undo the bf16-specific algorithm restriction so subsequent
                // iterations choose algorithms freely
                checker.reset_before_exec_callback();
                checker.opr()->execution_policy() = {};
            }
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        // UNSET_STRIDE_VAL means contiguous; otherwise build the layout with
        // the explicit row stride supplied by the test arg.
        TensorLayout AL, BL, CL;
        if (arg.A_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1}, dtype::Float32());
        }
        if (arg.B_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1}, dtype::Float32());
        }
        if (arg.C_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1}, dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
TEST_F(CUDA, MATRIX_MUL_CUBLASLT) {
    // Pin the CUBLAS_LT algorithm and exercise it with int8, f32 and f16
    // inputs, then run the shared stride-aware shape list. Requires compute
    // capability >= 7.5.
    require_compute_capability(7, 5);
    NormalRNG normal_rng;
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_rng(0, &normal_rng)
            .set_rng(1, &normal_rng)
            .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
    using Param = MatrixMul::Param;
    size_t m = 32, n = 32, k = 32;
    // test Int8 matmul
    {
        DType dtype = dtype::Int32();
        Param param;
        param.transposeA = false;
        param.transposeB = false;
        // Int32 output denotes int8 x int8 -> int32 here
        DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
        TensorShape A, B;
        A = TensorShape{m, k};
        B = TensorShape{k, n};
        checker.set_param(param)
                .set_dtype(0, stype)
                .set_dtype(1, stype)
                .set_dtype(2, dtype)
                .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
                .execs({A, B, {}});
    }
    // test float-point matmul
    for (DType dtype : std::array<DType, 2>{{dtype::Float32(), dtype::Float16()}}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            // f16 gets a looser tolerance than f32
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3)
                    .execs({A, B, {}});
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        // UNSET_STRIDE_VAL means contiguous; otherwise use the explicit row
        // stride from the test arg.
        TensorLayout AL, BL, CL;
        if (arg.A_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1}, dtype::Float32());
        }
        if (arg.B_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1}, dtype::Float32());
        }
        if (arg.C_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1}, dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
  372. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
  373. require_compute_capability(7, 5);
  374. size_t m = 12, n = 16, k = 20;
  375. Checker<MatrixMul> checker(handle_cuda());
  376. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  377. using Param = MatrixMul::Param;
  378. Param param;
  379. DType stype = dtype::Float32();
  380. DType dtype = dtype::Float32();
  381. TensorShape A, B;
  382. param.transposeA = param.transposeB = 1;
  383. if (param.transposeA)
  384. A = TensorShape{k, m};
  385. else
  386. A = TensorShape{m, k};
  387. if (param.transposeB)
  388. B = TensorShape{n, k};
  389. else
  390. B = TensorShape{k, n};
  391. checker.set_param(param)
  392. .set_dtype(0, stype)
  393. .set_dtype(1, stype)
  394. .set_dtype(2, dtype)
  395. .set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2)
  396. .execs({A, B, {}});
  397. }
  398. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_INT8) {
  399. require_compute_capability(7, 5);
  400. NormalRNG normal_rng;
  401. Checker<MatrixMul> checker(handle_cuda());
  402. checker.set_rng(0, &normal_rng)
  403. .set_rng(1, &normal_rng)
  404. .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  405. using Param = MatrixMul::Param;
  406. // size_t m = 32, n = 32, k = 32;
  407. // test Int8 matmul
  408. for (size_t m = 8; m <= 64; m += 4)
  409. for (size_t n = 8; n <= 64; n += 4)
  410. for (size_t k = 8; k <= 64; k += 4) {
  411. DType dtype = dtype::Int32();
  412. Param param;
  413. param.transposeA = false;
  414. param.transposeB = false;
  415. DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
  416. TensorShape A, B;
  417. A = TensorShape{m, k};
  418. B = TensorShape{k, n};
  419. checker.set_param(param)
  420. .set_dtype(0, stype)
  421. .set_dtype(1, stype)
  422. .set_dtype(2, dtype)
  423. .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
  424. .execs({A, B, {}});
  425. }
  426. }
  427. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
  428. require_compute_capability(7, 5);
  429. size_t m = 128, n = 1024, k = 18432;
  430. Checker<MatrixMul> checker(handle_cuda());
  431. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  432. using Param = MatrixMul::Param;
  433. Param param;
  434. DType stype = dtype::Float32();
  435. DType dtype = dtype::Float32();
  436. TensorShape A, B;
  437. param.transposeA = param.transposeB = 0;
  438. if (param.transposeA)
  439. A = TensorShape{k, m};
  440. else
  441. A = TensorShape{m, k};
  442. if (param.transposeB)
  443. B = TensorShape{n, k};
  444. else
  445. B = TensorShape{k, n};
  446. checker.set_param(param)
  447. .set_dtype(0, stype)
  448. .set_dtype(1, stype)
  449. .set_dtype(2, dtype)
  450. .execs({A, B, {}});
  451. }
  452. TEST_F(CUDA, MATRIX_MUL_CUDNN_F32_uncont) {
  453. Checker<MatrixMul> checker(handle_cuda());
  454. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
  455. using Param = MatrixMul::Param;
  456. size_t m = 100, n = 100, k = 100;
  457. Param param;
  458. param.transposeA = 1;
  459. param.transposeB = 1;
  460. TensorLayout A{{m, k}, {128, 1}, dtype::Float32()},
  461. B{{k, n}, {128, 1}, dtype::Float32()}, C{{m, n}, dtype::Float32()};
  462. DType stype = dtype::Float32();
  463. DType dtype = dtype::Float32();
  464. checker.set_param(param)
  465. .set_dtype(0, stype)
  466. .set_dtype(1, stype)
  467. .set_dtype(2, dtype)
  468. .execl({A, B, {}});
  469. }
  470. TEST_F(CUDA, MATRIX_MUL_CUDNN_F32) {
  471. Checker<MatrixMul> checker(handle_cuda());
  472. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
  473. using Param = MatrixMul::Param;
  474. for (size_t m = 8; m <= 64; m += 4) {
  475. for (size_t n = 8; n <= 64; n += 4) {
  476. for (size_t k = 8; k <= 64; k += 4) {
  477. for (unsigned mask = 0; mask < 4; ++mask) {
  478. Param param;
  479. param.transposeA = mask & 1;
  480. param.transposeB = mask & 2;
  481. DType stype = dtype::Float32();
  482. DType dtype = dtype::Float32();
  483. TensorShape A, B;
  484. if (param.transposeA)
  485. A = TensorShape{k, m};
  486. else
  487. A = TensorShape{m, k};
  488. if (param.transposeB)
  489. B = TensorShape{n, k};
  490. else
  491. B = TensorShape{k, n};
  492. checker.set_param(param)
  493. .set_dtype(0, stype)
  494. .set_dtype(1, stype)
  495. .set_dtype(2, dtype)
  496. .execs({A, B, {}});
  497. }
  498. }
  499. }
  500. }
  501. }
  502. TEST_F(CUDA, MATRIX_MUL_CUDNN_F16) {
  503. Checker<MatrixMul> checker(handle_cuda());
  504. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("MATMUL_CONV1X1"));
  505. using Param = MatrixMul::Param;
  506. for (size_t m = 8; m <= 64; m += 4) {
  507. for (size_t n = 8; n <= 64; n += 4) {
  508. for (size_t k = 8; k <= 64; k += 4) {
  509. for (unsigned mask = 0; mask < 4; ++mask) {
  510. Param param;
  511. param.transposeA = mask & 1;
  512. param.transposeB = mask & 2;
  513. DType stype = dtype::Float16();
  514. DType dtype = dtype::Float16();
  515. TensorShape A, B;
  516. if (param.transposeA)
  517. A = TensorShape{k, m};
  518. else
  519. A = TensorShape{m, k};
  520. if (param.transposeB)
  521. B = TensorShape{n, k};
  522. else
  523. B = TensorShape{k, n};
  524. checker.set_param(param)
  525. .set_dtype(0, stype)
  526. .set_dtype(1, stype)
  527. .set_dtype(2, dtype)
  528. .set_epsilon(6e-2)
  529. .execs({A, B, {}});
  530. }
  531. }
  532. }
  533. }
  534. }
  535. } // namespace test
  536. } // namespace megdnn
  537. // vim: syntax=cpp.doxygen