You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. /**
  2. * \file dnn/test/cuda/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/matrix_mul.h"
  14. #include "test/common/benchmarker.h"
  15. #include "src/cuda/utils.h"
  16. #if defined(cuda_check)
  17. #undef cuda_check
  18. #endif
  19. #include "test/cuda/utils.h"
  20. #include <cuda.h>
  21. namespace megdnn {
  22. namespace test {
  23. #if CUDA_VERSION >= 10000
  24. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION) {
  25. if (cuda::current_device_prop().major > 7 ||
  26. (cuda::current_device_prop().major == 7 &&
  27. cuda::current_device_prop().minor >= 5)) {
  28. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION test as current "
  29. "device support wmma intrinsics\n");
  30. return;
  31. }
  32. Checker<MatrixMul> checker(handle_cuda(), false);
  33. using Param = MatrixMul::Param;
  34. Param param;
  35. param.transposeB = true;
  36. checker.set_param(param);
  37. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  38. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  39. checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
  40. ASSERT_THROW(checker.exec({{256, 256}, {256, 256}, {256, 256}}),
  41. MegDNNError);
  42. }
  43. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32) {
  44. if (cuda::current_device_prop().major < 7 ||
  45. (cuda::current_device_prop().major == 7 &&
  46. cuda::current_device_prop().minor < 5)) {
  47. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32 test as current device "
  48. "doesn't support\n");
  49. return;
  50. }
  51. Checker<MatrixMul> checker(handle_cuda(), false);
  52. using Param = MatrixMul::Param;
  53. Param param;
  54. param.transposeB = true;
  55. checker.set_param(param);
  56. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  57. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  58. checker.set_dtype(2, dtype::QuantizedS32(1.3f*1.3f));
  59. checker.exec({{256, 256}, {256, 256}, {256, 256}});
  60. auto args = matrix_mul::get_matmul_args();
  61. for (auto arg : args) {
  62. size_t m = DIVUP(arg.m, 8) * 8, n = DIVUP(arg.n, 8) * 8,
  63. k = DIVUP(arg.k, 32) * 32;
  64. checker.exec({{m, k}, {n, k}, {m, n}});
  65. }
  66. }
  67. #if MEGDNN_WITH_BENCHMARK
  68. TEST_F(CUDA, BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  69. if (cuda::current_device_prop().major < 7 ||
  70. (cuda::current_device_prop().major == 7 &&
  71. cuda::current_device_prop().minor < 5)) {
  72. printf("Skip CUDA.BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as current "
  73. "device doesn't support\n");
  74. return;
  75. }
  76. Benchmarker<MatrixMul> bencher(handle_cuda());
  77. using Param = MatrixMul::Param;
  78. Param param;
  79. param.transposeB = true;
  80. bencher.set_param(param);
  81. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  82. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  83. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  84. for (size_t m : {256, 1024, 4096, 10240, 40960}) {
  85. for (size_t n : {256, 1024, 4096}) {
  86. for (size_t k :{512, 1024, 2048}) {
  87. bencher.set_times(400);
  88. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  89. auto gflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  90. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n",
  91. m, k, n, time_in_ms, gflps);
  92. }
  93. }
  94. }
  95. }
  96. TEST_F(CUDA, PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  97. if (cuda::current_device_prop().major < 7 ||
  98. (cuda::current_device_prop().major == 7 &&
  99. cuda::current_device_prop().minor < 5)) {
  100. printf("Skip CUDA.PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as "
  101. "current "
  102. "device doesn't support\n");
  103. return;
  104. }
  105. Benchmarker<MatrixMul> bencher(handle_cuda());
  106. using Param = MatrixMul::Param;
  107. Param param;
  108. param.transposeB = true;
  109. bencher.set_param(param);
  110. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  111. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  112. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  113. bencher.set_times(400);
  114. size_t m = 4096, n = 4096, k = 81920;
  115. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  116. auto tflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  117. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m, k, n,
  118. time_in_ms, tflps);
  119. }
  120. #endif
  121. #endif
  122. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_WITH_SPETIAL_STRIDES) {
  123. if (!cuda::is_compute_capability_required(6, 1)) {
  124. printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
  125. return;
  126. }
  127. Checker<MatrixMul> checker(handle_cuda());
  128. using Param = MatrixMul::Param;
  129. Param param;
  130. DType stype = dtype::Int8();
  131. checker.set_param(param)
  132. .set_dtype(0, stype)
  133. .set_dtype(1, stype)
  134. .set_dtype(2, dtype::Int32())
  135. .set_epsilon(5e-3);
  136. size_t m = 1024, n = 1024, k = 1024;
  137. {
  138. TensorLayout A{{m, k}, {2048, 1}, dtype::Int8()},
  139. B{{k, n}, {2048, 1}, dtype::Int8()}, C{{m, n}, dtype::Int32()};
  140. checker.execl({A, B, {}});
  141. }
  142. }
  143. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_NAIVE) {
  144. if (!cuda::is_compute_capability_required(6, 1)) {
  145. printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
  146. return;
  147. }
  148. using Param = MatrixMul::Param;
  149. UniformIntRNG rng{-128, 127};
  150. Checker<MatrixMul> checker(handle_cuda());
  151. checker.set_rng(0, &rng).set_rng(1, &rng);
  152. size_t m = 1007, n = 1003, k = 129;
  153. for (unsigned mask = 0; mask < 4; ++mask) {
  154. Param param;
  155. param.transposeA = mask & 1;
  156. param.transposeB = mask & 2;
  157. TensorShape A, B;
  158. if (param.transposeA)
  159. A = TensorShape{k, m};
  160. else
  161. A = TensorShape{m, k};
  162. if (param.transposeB)
  163. B = TensorShape{n, k};
  164. else
  165. B = TensorShape{k, n};
  166. checker.set_param(param)
  167. .set_dtype(0, dtype::Int8())
  168. .set_dtype(1, dtype::Int8())
  169. .set_dtype(2, dtype::Int32())
  170. .set_epsilon(0)
  171. .execs({A, B, {}});
  172. }
  173. }
  174. TEST_F(CUDA, MATRIX_MUL_FLOAT_NAIVE) {
  175. Checker<MatrixMul> checker(handle_cuda());
  176. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("NAIVE"));
  177. using Param = MatrixMul::Param;
  178. size_t m = 12, n = 16, k = 20;
  179. std::vector<DType> dtype_array;
  180. dtype_array.push_back(dtype::Float32());
  181. dtype_array.push_back(dtype::Float16());
  182. for (DType dtype : dtype_array) {
  183. for (unsigned mask = 0; mask < 4; ++mask) {
  184. Param param;
  185. param.transposeA = mask & 1;
  186. param.transposeB = mask & 2;
  187. DType stype = dtype;
  188. TensorShape A, B;
  189. if (param.transposeA)
  190. A = TensorShape{k, m};
  191. else
  192. A = TensorShape{m, k};
  193. if (param.transposeB)
  194. B = TensorShape{n, k};
  195. else
  196. B = TensorShape{k, n};
  197. if (dtype == dtype::Float16()) {
  198. param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
  199. }
  200. checker.set_param(param)
  201. .set_dtype(0, stype)
  202. .set_dtype(1, stype)
  203. .set_dtype(2, dtype)
  204. .set_epsilon(dtype == dtype::Float16()
  205. ? 5e-2
  206. : 5e-3)
  207. .execs({A, B, {}});
  208. }
  209. }
  210. }
// End-to-end MatrixMul check on CUDA: a dtype sweep (f32 / f16 / bf16 and,
// when compute capability 6.1 is available, int8 -> int32) over all four
// transpose combinations, followed by the generic shape/stride argument list.
TEST_F(CUDA, MATRIX_MUL) {
    if (cuda::current_device_prop().major < 6) {
        printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
        return;
    }
    Checker<MatrixMul> checker(handle_cuda());
    using Param = MatrixMul::Param;
    size_t m = 12, n = 16, k = 20;
    // The int8x8x32 path is only added on compute capability >= 6.1.
    bool is_int_available = cuda::is_compute_capability_required(6, 1);
    std::vector<DType> dtype_array;
    dtype_array.push_back(dtype::Float32());
    dtype_array.push_back(dtype::Float16());
    dtype_array.push_back(dtype::BFloat16());
    if (is_int_available)
        dtype_array.push_back(dtype::Int32());
    for (DType dtype : dtype_array) {
        // mask bit 0 selects transposeA, bit 1 selects transposeB.
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            // An Int32 output dtype means int8 inputs; otherwise the source
            // dtype equals the destination dtype.
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            if (dtype == dtype::BFloat16()) {
                // bf16 is forced through the MATMUL_BFLOAT16 algorithm
                // (wrapping CUBLAS) with f32 compute mode.
                param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
                checker.set_before_exec_callback(
                        AlgoChecker<MatrixMulForward>(ExecutionPolicyAlgoName{
                                "MATMUL_BFLOAT16", {{"CUBLAS", {}}}}));
            }
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    // Looser tolerance for the half-precision types.
                    .set_epsilon(dtype == dtype::Float16() ||
                                                 dtype == dtype::BFloat16()
                                         ? 5e-2
                                         : 5e-3)
                    .execs({A, B, {}});
            if (dtype == dtype::BFloat16()) {
                // Undo the bf16-specific algorithm restriction so later
                // iterations pick their algorithms freely.
                checker.reset_before_exec_callback();
                checker.opr()->execution_policy() = {};
            }
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        TensorLayout AL, BL, CL;
        // UNSET_STRIDE_VAL selects the default contiguous layout; otherwise
        // the explicit row stride from the arg is applied.
        if (arg.A_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1},
                              dtype::Float32());
        }
        if (arg.B_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1},
                              dtype::Float32());
        }
        if (arg.C_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1},
                              dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
// Exercises the CUBLAS_LT algorithm (requires compute capability 7.5):
// one int8 -> int32 case, f32/f16 over all four transpose combinations,
// then the generic shape/stride argument list.
TEST_F(CUDA, MATRIX_MUL_CUBLASLT)
{
    require_compute_capability(7, 5);
    NormalRNG normal_rng;
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_rng(0, &normal_rng)
            .set_rng(1, &normal_rng)
            .set_before_exec_callback(
                    AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
    using Param = MatrixMul::Param;
    size_t m = 32, n = 32, k = 32;
    // test Int8 matmul: int8 inputs, int32 output, no transposition.
    {
        DType dtype = dtype::Int32();
        Param param;
        param.transposeA = false;
        param.transposeB = false;
        // Int32 output dtype implies int8 inputs.
        DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
        TensorShape A, B;
        A = TensorShape{m, k};
        B = TensorShape{k, n};
        checker.set_param(param)
                .set_dtype(0, stype)
                .set_dtype(1, stype)
                .set_dtype(2, dtype)
                .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3)
                .execs({A, B, {}});
    }
    // test float-point matmul over every transposeA/transposeB combination.
    for (DType dtype : std::array<DType, 2>{
                 {dtype::Float32(), dtype::Float16()}}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            // Looser tolerance for half precision.
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    .set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3)
                    .execs({A, B, {}});
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        TensorLayout AL, BL, CL;
        // UNSET_STRIDE_VAL selects the default contiguous layout; otherwise
        // the explicit row stride from the arg is applied.
        if (arg.A_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1},
                              dtype::Float32());
        }
        if (arg.B_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1},
                              dtype::Float32());
        }
        if (arg.C_stride == matrix_mul::TestArg::UNSET_STRIDE_VAL) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1},
                              dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
  395. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
  396. require_compute_capability(7, 5);
  397. size_t m = 12, n = 16, k = 20;
  398. Checker<MatrixMul> checker(handle_cuda());
  399. checker.set_before_exec_callback(
  400. AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  401. using Param = MatrixMul::Param;
  402. Param param;
  403. DType stype = dtype::Float32();
  404. DType dtype = dtype::Float32();
  405. TensorShape A, B;
  406. param.transposeA=param.transposeB=1;
  407. if (param.transposeA)
  408. A = TensorShape{k, m};
  409. else
  410. A = TensorShape{m, k};
  411. if (param.transposeB)
  412. B = TensorShape{n, k};
  413. else
  414. B = TensorShape{k, n};
  415. checker.set_param(param).
  416. set_dtype(0, stype).
  417. set_dtype(1, stype).
  418. set_dtype(2, dtype).
  419. set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2).
  420. execs({A, B, {}});
  421. }
  422. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_INT8) {
  423. require_compute_capability(7, 5);
  424. NormalRNG normal_rng;
  425. Checker<MatrixMul> checker(handle_cuda());
  426. checker.set_rng(0, &normal_rng)
  427. .set_rng(1, &normal_rng)
  428. .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  429. using Param = MatrixMul::Param;
  430. //size_t m = 32, n = 32, k = 32;
  431. // test Int8 matmul
  432. for (size_t m=8; m<=64; m+=4)
  433. for (size_t n=8; n<=64; n+=4)
  434. for (size_t k=8; k<=64; k+=4)
  435. {
  436. DType dtype=dtype::Int32();
  437. Param param;
  438. param.transposeA = false;
  439. param.transposeB = false;
  440. DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
  441. TensorShape A, B;
  442. A = TensorShape{m, k};
  443. B = TensorShape{k, n};
  444. checker.set_param(param).
  445. set_dtype(0, stype).
  446. set_dtype(1, stype).
  447. set_dtype(2, dtype).
  448. set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
  449. execs({A, B, {}});
  450. }
  451. }
  452. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_F32) {
  453. require_compute_capability(7, 5);
  454. size_t m = 128, n = 1024, k = 18432;
  455. Checker<MatrixMul> checker(handle_cuda());
  456. checker.set_before_exec_callback(
  457. AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  458. using Param = MatrixMul::Param;
  459. Param param;
  460. DType stype = dtype::Float32();
  461. DType dtype = dtype::Float32();
  462. TensorShape A, B;
  463. param.transposeA = param.transposeB = 0;
  464. if (param.transposeA)
  465. A = TensorShape{k, m};
  466. else
  467. A = TensorShape{m, k};
  468. if (param.transposeB)
  469. B = TensorShape{n, k};
  470. else
  471. B = TensorShape{k, n};
  472. checker.set_param(param)
  473. .set_dtype(0, stype)
  474. .set_dtype(1, stype)
  475. .set_dtype(2, dtype)
  476. .execs({A, B, {}});
  477. }
  478. } // namespace test
  479. } // namespace megdnn
  480. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台