You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

matrix_mul.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. /**
  2. * \file dnn/test/cuda/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/matrix_mul.h"
  14. #include "test/common/benchmarker.h"
  15. #include "src/cuda/utils.h"
  16. #if defined(cuda_check)
  17. #undef cuda_check
  18. #endif
  19. #include "test/cuda/utils.h"
  20. #include <cuda.h>
  21. namespace megdnn {
  22. namespace test {
  23. #if CUDA_VERSION >= 10000
  24. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION) {
  25. if (cuda::current_device_prop().major > 7 ||
  26. (cuda::current_device_prop().major == 7 &&
  27. cuda::current_device_prop().minor >= 5)) {
  28. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32_EXCEPTION test as current "
  29. "device support wmma intrinsics\n");
  30. return;
  31. }
  32. Checker<MatrixMul> checker(handle_cuda(), false);
  33. using Param = MatrixMul::Param;
  34. Param param;
  35. param.transposeB = true;
  36. checker.set_param(param);
  37. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  38. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  39. checker.set_dtype(2, dtype::QuantizedS32(1.3f * 1.3f));
  40. ASSERT_THROW(checker.exec({{256, 256}, {256, 256}, {256, 256}}),
  41. MegDNNError);
  42. }
  43. TEST_F(CUDA, MATRIX_MUL_QUANTIZED4x4x32) {
  44. if (cuda::current_device_prop().major < 7 ||
  45. (cuda::current_device_prop().major == 7 &&
  46. cuda::current_device_prop().minor < 5)) {
  47. printf("Skip CUDA.MATRIX_MUL_QUANTIZED4x4x32 test as current device doesn't support\n");
  48. return;
  49. }
  50. Checker<MatrixMul> checker(handle_cuda(), false);
  51. using Param = MatrixMul::Param;
  52. Param param;
  53. param.transposeB = true;
  54. checker.set_param(param);
  55. checker.set_dtype(0, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  56. checker.set_dtype(1, dtype::Quantized4Asymm(1.3f, (uint8_t)3));
  57. checker.set_dtype(2, dtype::QuantizedS32(1.3f*1.3f));
  58. checker.exec({{256, 256}, {256, 256}, {256, 256}});
  59. auto args = matrix_mul::get_matmul_args();
  60. for (auto arg : args) {
  61. size_t m = DIVUP(arg.m, 8) * 8, n = DIVUP(arg.n, 8) * 8,
  62. k = DIVUP(arg.k, 32) * 32;
  63. checker.exec({{m, k}, {n, k}, {m, n}});
  64. }
  65. }
  66. #if MEGDNN_WITH_BENCHMARK
  67. TEST_F(CUDA, BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  68. if (cuda::current_device_prop().major < 7 ||
  69. (cuda::current_device_prop().major == 7 &&
  70. cuda::current_device_prop().minor < 5)) {
  71. printf("Skip CUDA.BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as current "
  72. "device doesn't support\n");
  73. return;
  74. }
  75. Benchmarker<MatrixMul> bencher(handle_cuda());
  76. using Param = MatrixMul::Param;
  77. Param param;
  78. param.transposeB = true;
  79. bencher.set_param(param);
  80. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  81. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  82. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  83. for (size_t m : {256, 1024, 4096, 10240, 40960}) {
  84. for (size_t n : {256, 1024, 4096}) {
  85. for (size_t k :{512, 1024, 2048}) {
  86. bencher.set_times(400);
  87. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  88. auto gflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  89. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n",
  90. m, k, n, time_in_ms, gflps);
  91. }
  92. }
  93. }
  94. }
  95. TEST_F(CUDA, PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32) {
  96. if (cuda::current_device_prop().major < 7 ||
  97. (cuda::current_device_prop().major == 7 &&
  98. cuda::current_device_prop().minor < 5)) {
  99. printf("Skip CUDA.PEAK_BENCHMARK_MATRIX_MUL_QUANTIZED4x4x32 test as "
  100. "current "
  101. "device doesn't support\n");
  102. return;
  103. }
  104. Benchmarker<MatrixMul> bencher(handle_cuda());
  105. using Param = MatrixMul::Param;
  106. Param param;
  107. param.transposeB = true;
  108. bencher.set_param(param);
  109. bencher.set_dtype(0, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  110. bencher.set_dtype(1, dtype::Quantized4Asymm(1.0f, (uint8_t)3));
  111. bencher.set_dtype(2, dtype::QuantizedS32(1.0f));
  112. bencher.set_times(400);
  113. size_t m = 4096, n = 4096, k = 81920;
  114. auto time_in_ms = bencher.exec({{m, k}, {n, k}, {m, n}}) / 400;
  115. auto tflps = 2.0 * m * k * n / (time_in_ms * 1e-3) * 1e-12;
  116. printf("m=%zu, k=%zu, n=%zu, time: %fms, perf: %f TFlops\n", m, k, n,
  117. time_in_ms, tflps);
  118. }
  119. #endif
  120. #endif
  121. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_WITH_SPETIAL_STRIDES) {
  122. if (!cuda::is_compute_capability_required(6, 1)) {
  123. printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
  124. return;
  125. }
  126. Checker<MatrixMul> checker(handle_cuda());
  127. using Param = MatrixMul::Param;
  128. Param param;
  129. DType stype = dtype::Int8();
  130. checker.set_param(param)
  131. .set_dtype(0, stype)
  132. .set_dtype(1, stype)
  133. .set_dtype(2, dtype::Int32())
  134. .set_epsilon(5e-3);
  135. size_t m = 1024, n = 1024, k = 1024;
  136. {
  137. TensorLayout A{{m, k}, {2048, 1}, dtype::Int8()},
  138. B{{k, n}, {2048, 1}, dtype::Int8()}, C{{m, n}, dtype::Int32()};
  139. checker.execl({A, B, {}});
  140. }
  141. }
  142. TEST_F(CUDA, MATRIX_MUL_INT8x8x32_NAIVE) {
  143. if (!cuda::is_compute_capability_required(6, 1)) {
  144. printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
  145. return;
  146. }
  147. using Param = MatrixMul::Param;
  148. UniformIntRNG rng{-128, 127};
  149. Checker<MatrixMul> checker(handle_cuda());
  150. checker.set_rng(0, &rng).set_rng(1, &rng);
  151. size_t m = 1007, n = 1003, k = 129;
  152. for (unsigned mask = 0; mask < 4; ++mask) {
  153. Param param;
  154. param.transposeA = mask & 1;
  155. param.transposeB = mask & 2;
  156. TensorShape A, B;
  157. if (param.transposeA)
  158. A = TensorShape{k, m};
  159. else
  160. A = TensorShape{m, k};
  161. if (param.transposeB)
  162. B = TensorShape{n, k};
  163. else
  164. B = TensorShape{k, n};
  165. checker.set_param(param)
  166. .set_dtype(0, dtype::Int8())
  167. .set_dtype(1, dtype::Int8())
  168. .set_dtype(2, dtype::Int32())
  169. .set_epsilon(0)
  170. .execs({A, B, {}});
  171. }
  172. }
// General MatrixMul correctness on CUDA: float32/float16/bfloat16 (plus
// int8x8x32 when sm_61+ is available), all four transpose combinations on a
// small fixed shape, followed by a sweep of shapes/strides from the common
// test-arg list.
TEST_F(CUDA, MATRIX_MUL) {
    if (cuda::current_device_prop().major < 6) {
        printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");
        return;
    }
    Checker<MatrixMul> checker(handle_cuda());
    using Param = MatrixMul::Param;
    size_t m = 12, n = 16, k = 20;
    // int8 inputs accumulating to int32 need compute capability >= 6.1.
    bool is_int_available = cuda::is_compute_capability_required(6, 1);
    std::vector<DType> dtype_array;
    dtype_array.push_back(dtype::Float32());
    dtype_array.push_back(dtype::Float16());
    dtype_array.push_back(dtype::BFloat16());
    if (is_int_available)
        dtype_array.push_back(dtype::Int32());
    for (DType dtype : dtype_array) {
        // mask bit0 -> transposeA, bit1 -> transposeB.
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            // Int32 output implies Int8 inputs; otherwise source dtype
            // equals destination dtype.
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            if (dtype == dtype::BFloat16()) {
                // bfloat16 runs through the MATMUL_BFLOAT16 wrapper algo
                // (delegating to CUBLAS) with float32 compute mode.
                param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
                checker.set_before_exec_callback(
                        AlgoChecker<MatrixMulForward>(ExecutionPolicyAlgoName{
                                "MATMUL_BFLOAT16", {{"CUBLAS", {}}}}));
            }
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    // Looser tolerance for the half-precision dtypes.
                    .set_epsilon(dtype == dtype::Float16() ||
                                                 dtype == dtype::BFloat16()
                                         ? 5e-2
                                         : 5e-3)
                    .execs({A, B, {}});
            if (dtype == dtype::BFloat16()) {
                // Restore default algo selection for the next iteration.
                checker.reset_before_exec_callback();
                checker.opr()->execution_policy() = {};
            }
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        // A stride of 0 in the arg means "use the contiguous default";
        // otherwise build an explicitly strided layout.
        TensorLayout AL, BL, CL;
        if (arg.A_stride == 0) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1},
                              dtype::Float32());
        }
        if (arg.B_stride == 0) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1},
                              dtype::Float32());
        }
        if (arg.C_stride == 0) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1},
                              dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
// Exercises the CUBLAS_LT algorithm specifically (forced via the before-exec
// AlgoChecker): one int8x8x32 case, float32/float16 with all transpose
// combinations, then the common shape/stride sweep. Requires sm_75+.
TEST_F(CUDA, MATRIX_MUL_CUBLASLT)
{
    require_compute_capability(7, 5);
    NormalRNG normal_rng;
    Checker<MatrixMul> checker(handle_cuda());
    checker.set_rng(0, &normal_rng)
            .set_rng(1, &normal_rng)
            .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
    using Param = MatrixMul::Param;
    size_t m = 32, n = 32, k = 32;
    // test Int8 matmul
    {
        DType dtype = dtype::Int32();
        Param param;
        param.transposeA = false;
        param.transposeB = false;
        // Int32 output implies Int8 inputs.
        DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
        TensorShape A, B;
        A = TensorShape{m, k};
        B = TensorShape{k, n};
        checker.set_param(param).
                set_dtype(0, stype).
                set_dtype(1, stype).
                set_dtype(2, dtype).
                set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
                execs({A, B, {}});
    }
    // test float-point matmul
    for (DType dtype : std::array<DType, 2>{
                 {dtype::Float32(), dtype::Float16()}}) {
        // mask bit0 -> transposeA, bit1 -> transposeB.
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            TensorShape A, B;
            if (param.transposeA)
                A = TensorShape{k, m};
            else
                A = TensorShape{m, k};
            if (param.transposeB)
                B = TensorShape{n, k};
            else
                B = TensorShape{k, n};
            checker.set_param(param).
                    set_dtype(0, stype).
                    set_dtype(1, stype).
                    set_dtype(2, dtype).
                    // Looser tolerance for float16.
                    set_epsilon(dtype == dtype::Float16() ? 5e-2 : 8e-3).
                    execs({A, B, {}});
        }
    }
    // general tests
    auto args = matrix_mul::get_matmul_args();
    for (auto arg : args) {
        auto m = arg.m, n = arg.n, k = arg.k;
        auto mask = arg.mask;
        Param param;
        param.transposeA = mask & 1;
        param.transposeB = mask & 2;
        TensorShape AS, BS, CS;
        if (param.transposeA)
            AS = TensorShape{k, m};
        else
            AS = TensorShape{m, k};
        if (param.transposeB)
            BS = TensorShape{n, k};
        else
            BS = TensorShape{k, n};
        CS = TensorShape{m, n};
        // A stride of 0 means "contiguous default"; otherwise build an
        // explicitly strided layout.
        TensorLayout AL, BL, CL;
        if (arg.A_stride == 0) {
            AL = TensorLayout(AS, dtype::Float32());
        } else {
            AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1},
                              dtype::Float32());
        }
        if (arg.B_stride == 0) {
            BL = TensorLayout(BS, dtype::Float32());
        } else {
            BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1},
                              dtype::Float32());
        }
        if (arg.C_stride == 0) {
            CL = TensorLayout(CS, dtype::Float32());
        } else {
            CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1},
                              dtype::Float32());
        }
        checker.set_param(param).execl({AL, BL, CL});
    }
}
  356. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_SPECIAL_CASE) {
  357. require_compute_capability(7, 5);
  358. size_t m = 12, n = 16, k = 20;
  359. Checker<MatrixMul> checker(handle_cuda());
  360. checker.set_before_exec_callback(
  361. AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  362. using Param = MatrixMul::Param;
  363. Param param;
  364. DType stype = dtype::Float32();
  365. DType dtype = dtype::Float32();
  366. TensorShape A, B;
  367. param.transposeA=param.transposeB=1;
  368. if (param.transposeA)
  369. A = TensorShape{k, m};
  370. else
  371. A = TensorShape{m, k};
  372. if (param.transposeB)
  373. B = TensorShape{n, k};
  374. else
  375. B = TensorShape{k, n};
  376. checker.set_param(param).
  377. set_dtype(0, stype).
  378. set_dtype(1, stype).
  379. set_dtype(2, dtype).
  380. set_epsilon(dtype == dtype::Float16() ? 5e-1 : 5e-2).
  381. execs({A, B, {}});
  382. }
  383. TEST_F(CUDA, MATRIX_MUL_CUBLASLT_INT8) {
  384. require_compute_capability(7, 5);
  385. NormalRNG normal_rng;
  386. Checker<MatrixMul> checker(handle_cuda());
  387. checker.set_rng(0, &normal_rng)
  388. .set_rng(1, &normal_rng)
  389. .set_before_exec_callback(AlgoChecker<MatrixMulForward>("CUBLAS_LT"));
  390. using Param = MatrixMul::Param;
  391. //size_t m = 32, n = 32, k = 32;
  392. // test Int8 matmul
  393. for (size_t m=8; m<=64; m+=4)
  394. for (size_t n=8; n<=64; n+=4)
  395. for (size_t k=8; k<=64; k+=4)
  396. {
  397. DType dtype=dtype::Int32();
  398. Param param;
  399. param.transposeA = false;
  400. param.transposeB = false;
  401. DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
  402. TensorShape A, B;
  403. A = TensorShape{m, k};
  404. B = TensorShape{k, n};
  405. checker.set_param(param).
  406. set_dtype(0, stype).
  407. set_dtype(1, stype).
  408. set_dtype(2, dtype).
  409. set_epsilon(dtype == dtype::Float16() ? 5e-2 : 5e-3).
  410. execs({A, B, {}});
  411. }
  412. }
  413. } // namespace test
  414. } // namespace megdnn
  415. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台