You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. /**
  2. * \file dnn/test/common/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/utils.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/matrix_mul.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. constexpr size_t matrix_mul::TestArg::UNSET_STRIDE_VAL;
  19. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() {
  20. std::vector<TestArg> args;
  21. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 32})
  22. for (size_t n : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
  23. 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32})
  24. for (size_t k : {1, 2, 4, 8, 11, 12, 15, 16, 31, 32, 37})
  25. args.emplace_back(m, n, k, 0);
  26. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17})
  27. args.emplace_back(m, m + 1, m + 2, 0);
  28. for (size_t mbase : {11})
  29. for (size_t test_case_offset : {64, 256, 512}) {
  30. size_t mnk = mbase + test_case_offset;
  31. args.emplace_back(mnk, mnk, mnk, 0);
  32. return args;
  33. }
  34. return args;
  35. }
  36. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
  37. size_t nbase) {
  38. std::vector<TestArg> args;
  39. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11})
  40. for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
  41. for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11})
  42. args.emplace_back(m, n * nbase, k, 0);
  43. return args;
  44. }
  45. std::vector<matrix_mul::TestArg>
  46. matrix_mul::get_batched_matmul_args_cublaslt() {
  47. std::vector<TestArg> args;
  48. for (size_t m : {4, 6, 8, 16}) {
  49. for (size_t n : {4, 6, 8, 16}) {
  50. //[TODO]: the following test case are disabled due to the
  51. // cublasLt(version: 10020) produce wrong result when k in [65, 97],
  52. // so please uncomment it if the bug is fixed
  53. for (size_t k : {32, 64}) {
  54. args.emplace_back(m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
  55. TestArg::UNSET_STRIDE_VAL,
  56. TestArg::UNSET_STRIDE_VAL, 2);
  57. }
  58. }
  59. }
  60. return args;
  61. }
  62. std::vector<matrix_mul::TestArg>
  63. matrix_mul::get_batched_matmul_args_int8x8x32() {
  64. std::vector<TestArg> args;
  65. for (size_t m : {1, 2, 3, 4, 5, 8, 64}) {
  66. for (size_t n : {1, 2, 3, 4, 5, 8, 64}) {
  67. for (size_t k : {1, 2, 3, 4, 5, 8, 64}) {
  68. args.emplace_back(m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
  69. TestArg::UNSET_STRIDE_VAL,
  70. TestArg::UNSET_STRIDE_VAL, 2);
  71. }
  72. }
  73. }
  74. return args;
  75. }
  76. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(
  77. uint8_t mask) {
  78. std::vector<TestArg> args;
  79. std::vector<TestArg> args_temp = matrix_mul::get_matmul_args_no_mask();
  80. for (auto arg : args_temp) {
  81. arg.mask = mask;
  82. args.emplace_back(arg);
  83. }
  84. // non-contiguous case
  85. for (size_t m : {110})
  86. for (size_t n : {119})
  87. for (size_t k : {120}) {
  88. // A: (m, k)
  89. size_t Astride = mask & 1 ? m + 2 : k + 2;
  90. // B: (k, n)
  91. size_t Bstride = mask & 2 ? k + 2 : n + 2;
  92. size_t Cstride = n * 2 + 2;
  93. args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
  94. }
  95. return args;
  96. }
  97. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args() {
  98. std::vector<TestArg> args;
  99. for (size_t mask = 0; mask < 4; ++mask) {
  100. std::vector<TestArg> args_temp = matrix_mul::get_matmul_args_mask(mask);
  101. for (auto arg : args_temp)
  102. args.emplace_back(arg);
  103. }
  104. return args;
  105. }
  106. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_split_k() {
  107. std::vector<TestArg> args = get_matmul_args();
  108. for (auto iter = args.begin(); iter < args.end();) {
  109. if (iter->k <= iter->n) {
  110. iter = args.erase(iter);
  111. } else {
  112. iter++;
  113. }
  114. }
  115. return args;
  116. }
  117. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_mask(
  118. uint8_t mask) {
  119. std::vector<TestArg> args;
  120. for (size_t b : {1, 2, 3}) {
  121. std::vector<TestArg> args_temp =
  122. megdnn::test::matrix_mul::get_matmul_args_mask(mask);
  123. for (auto arg : args_temp) {
  124. arg.b = b;
  125. args.emplace_back(arg);
  126. }
  127. }
  128. return args;
  129. }
  130. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() {
  131. std::vector<TestArg> args;
  132. for (size_t mask = 0; mask < 4; ++mask) {
  133. std::vector<TestArg> args_temp =
  134. matrix_mul::get_batched_matmul_args_mask(mask);
  135. for (auto arg : args_temp)
  136. args.emplace_back(arg);
  137. }
  138. return args;
  139. }
  140. std::vector<matrix_mul::TestArg>
  141. matrix_mul::get_batched_matmul_broadcast_args() {
  142. std::vector<TestArg> args;
  143. for (size_t mask = 0; mask < 4; ++mask) {
  144. std::vector<TestArg> args_temp =
  145. matrix_mul::get_batched_matmul_broadcast_args_mask(mask);
  146. for (auto arg : args_temp)
  147. args.emplace_back(arg);
  148. }
  149. return args;
  150. }
  151. std::vector<matrix_mul::TestArg>
  152. matrix_mul::get_batched_matmul_broadcast_args_mask(uint8_t mask) {
  153. std::vector<TestArg> args;
  154. std::vector<TestArg> args_temp =
  155. matrix_mul::get_batched_matmul_args_mask(mask);
  156. for (auto arg : args_temp) {
  157. args.emplace_back(arg);
  158. args.back().A_batch_stride = 0;
  159. }
  160. return args;
  161. }
  162. template <typename Opr>
  163. void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
  164. Handle* handle,
  165. const ExecutionPolicyAlgoName& algo,
  166. param::MatrixMul::Format format, size_t nbase,
  167. float eps, std::vector<TestArg>&& user_args,
  168. bool force_deduce_dst,
  169. param::MatrixMul::ComputeMode compute_mode) {
  170. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  171. Checker<Opr> checker(handle);
  172. checker.set_force_deduce_dst(force_deduce_dst);
  173. if (!algo.name.empty()) {
  174. checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
  175. }
  176. std::unique_ptr<RNG> rng;
  177. checker.set_epsilon(eps);
  178. if (A_dtype.enumv() == DTypeEnum::Int8 ||
  179. A_dtype.enumv() == DTypeEnum::QuantizedS8) {
  180. //! use larger rng to check the overflow
  181. rng = std::make_unique<UniformIntRNG>(-127, 127);
  182. } else if (A_dtype.enumv() == DTypeEnum::Uint8 ||
  183. A_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  184. rng = std::make_unique<NormalRNG>(128.f);
  185. } else if (A_dtype.enumv() == DTypeEnum::Int16) {
  186. rng = std::make_unique<UniformIntRNG>(-32767, 32767);
  187. } else if (A_dtype.enumv() == DTypeEnum::Float16) {
  188. rng = std::make_unique<NormalRNG>(2.f);
  189. //! if fp16 not set eps, default 1e-3, we just set it to 1e-2
  190. if (eps < 1e-2) {
  191. checker.set_epsilon(1e-2);
  192. }
  193. }
  194. if (rng) {
  195. checker.set_rng(0, rng.get()).set_rng(1, rng.get());
  196. }
  197. //! return expect if stride == -1, stride otherwise
  198. auto stride_val = [](size_t stride, size_t expect) -> size_t {
  199. if (stride == TestArg::UNSET_STRIDE_VAL) {
  200. return expect;
  201. } else {
  202. return stride;
  203. }
  204. };
  205. constexpr static bool batched =
  206. std::is_same<Opr, megdnn::BatchedMatrixMul>::value;
  207. using Param = MatrixMul::Param;
  208. std::vector<TestArg> args;
  209. if (user_args.empty()) {
  210. if (format == param::MatrixMul::Format::DEFAULT) {
  211. if (batched) {
  212. args = matrix_mul::get_batched_matmul_args();
  213. } else {
  214. args = matrix_mul::get_matmul_args();
  215. }
  216. } else {
  217. megdnn_assert(!batched,
  218. "BatchedMatrixMul does not support MK4/MK8");
  219. args = matrix_mul::get_matmul_mk_packed_args(nbase);
  220. }
  221. } else {
  222. args = user_args;
  223. }
  224. size_t pack_size = MatrixMulForward::pack_size(format);
  225. for (auto& arg : args) {
  226. size_t m = arg.m, n = arg.n, k = arg.k;
  227. if (handle->type() == Handle::HandleType::CUDA) {
  228. //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
  229. bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
  230. A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
  231. A_dtype.enumv() == DTypeEnum::Uint8 ||
  232. A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
  233. if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
  234. continue;
  235. }
  236. }
  237. Param param;
  238. param.transposeA = arg.mask & 0x1;
  239. param.transposeB = arg.mask & 0x2;
  240. param.compute_mode = compute_mode;
  241. param.format = format;
  242. checker.set_dtype(0, A_dtype)
  243. .set_dtype(1, B_dtype)
  244. .set_dtype(2, C_dtype);
  245. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  246. TensorShape A, B;
  247. if (param.transposeA) {
  248. std::swap(A0, A1);
  249. }
  250. if (param.transposeB) {
  251. std::swap(B0, B1);
  252. }
  253. ptrdiff_t A_stride = arg.A_stride, B_stride = arg.B_stride,
  254. C_stride = arg.C_stride, A_batch_stride = arg.A_batch_stride,
  255. B_batch_stride = arg.B_batch_stride,
  256. C_batch_stride = arg.C_batch_stride;
  257. A_stride = stride_val(A_stride, A1);
  258. B_stride = stride_val(B_stride, B1);
  259. C_stride = stride_val(C_stride, n);
  260. A_batch_stride = stride_val(A_batch_stride, A0 * A_stride);
  261. B_batch_stride = stride_val(B_batch_stride, B0 * B_stride);
  262. C_batch_stride = stride_val(C_batch_stride, m * C_stride);
  263. checker.set_param(param);
  264. if (format == param::MatrixMul::Format::DEFAULT) {
  265. if (batched) {
  266. checker.execl({TensorLayout{{arg.b, A0, A1},
  267. {A_batch_stride, A_stride, 1},
  268. A_dtype},
  269. TensorLayout{{arg.b, B0, B1},
  270. {B_batch_stride, B_stride, 1},
  271. B_dtype},
  272. TensorLayout{{arg.b, m, n},
  273. {C_batch_stride, C_stride, 1},
  274. C_dtype}});
  275. } else {
  276. checker.execl({TensorLayout{{A0, A1}, {A_stride, 1}, A_dtype},
  277. TensorLayout{{B0, B1}, {B_stride, 1}, B_dtype},
  278. TensorLayout{{m, n}, {C_stride, 1}, C_dtype}});
  279. }
  280. } else {
  281. //! ignore non-contiguous, only DEFAULT format support
  282. //! non-contiguous input
  283. checker.execs(
  284. {{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}});
  285. }
  286. }
  287. }
  288. void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
  289. DType C_dtype, Handle* handle,
  290. const ExecutionPolicyAlgoName& algo,
  291. float eps,
  292. std::vector<TestArg>&& args,
  293. bool force_deduce_dst) {
  294. check_matrix_mul<megdnn::BatchedMatrixMul>(
  295. A_dtype, B_dtype, C_dtype, handle, algo,
  296. param::MatrixMul::Format::DEFAULT, 8, eps,
  297. std::forward<decltype(args)>(args), force_deduce_dst);
  298. }
  299. void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
  300. Handle* handle,
  301. const ExecutionPolicyAlgoName& algo,
  302. param::MatrixMul::Format format, size_t nbase,
  303. float eps, bool force_deduce_dst) {
  304. check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
  305. format, nbase, eps, {},
  306. force_deduce_dst);
  307. }
  308. #if MEGDNN_WITH_BENCHMARK
  309. std::vector<matrix_mul::TestArg> matrix_mul::get_benchmark_matmul_args() {
  310. std::vector<matrix_mul::TestArg> args;
  311. args.emplace_back(256, 12 * 24, 256, 0);
  312. //////////////////////// gemv //////////////////////////
  313. for (size_t M : {8, 64, 112, 256}) {
  314. for (size_t K : {8, 64, 112, 256}) {
  315. args.emplace_back(M, 1, K, 0);
  316. }
  317. }
  318. //////////////////////// gemm //////////////////////////
  319. for (size_t M : {8, 64, 112, 256}) {
  320. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  321. for (size_t N : {8, 64, 112, 256}) {
  322. args.emplace_back(M, N, K, 0);
  323. }
  324. }
  325. }
  326. return args;
  327. }
  328. std::vector<matrix_mul::TestArg>
  329. matrix_mul::get_benchmark_matmul_mk_packed_args(size_t nbase) {
  330. std::vector<TestArg> args;
  331. for (size_t m : {2, 4, 8, 16, 24, 32, 64})
  332. for (size_t n : {1, 2, 3, 4, 8, 16, 32, 64})
  333. for (size_t k : {2, 4, 8, 16, 24, 32, 64})
  334. args.emplace_back(m, n * nbase, k, 0);
  335. return args;
  336. }
  337. void matrix_mul::benchmark_with_contrast(
  338. Handle* handle, const std::vector<TestArg>& args, DType A_dtype,
  339. DType B_dtype, DType C_dtype, const char* algo,
  340. param::MatrixMul::Format format, DType contrast_A_dtype,
  341. DType contrast_B_dtype, DType contrast_C_dtype,
  342. const char* contrast_algo, param::MatrixMul::Format contrast_format) {
  343. using Param = MatrixMul::Param;
  344. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  345. megdnn_assert(contrast_A_dtype.enumv() == contrast_B_dtype.enumv());
  346. Benchmarker<MatrixMul> benchmark_contrast(handle);
  347. Benchmarker<MatrixMul> benchmark(handle);
  348. constexpr size_t RUNS = 50;
  349. if (algo) {
  350. benchmark.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
  351. }
  352. if (contrast_algo) {
  353. benchmark_contrast.set_before_exec_callback(
  354. AlgoChecker<MatrixMul>(contrast_algo));
  355. }
  356. benchmark.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
  357. benchmark.set_times(RUNS);
  358. benchmark_contrast.set_dtype(0, contrast_A_dtype)
  359. .set_dtype(1, contrast_B_dtype)
  360. .set_dtype(2, contrast_C_dtype);
  361. benchmark_contrast.set_times(RUNS);
  362. auto bench = [](Benchmarker<MatrixMul>& benchmark, Param param,
  363. param::MatrixMul::Format format, size_t m, size_t n,
  364. size_t k, size_t pack_size) -> float {
  365. param.format = format;
  366. benchmark.set_param(param);
  367. float used_algo = 1.0;
  368. if (format == param::MatrixMul::Format::DEFAULT) {
  369. size_t A0 = m * pack_size, A1 = k * pack_size, B0 = k * pack_size,
  370. B1 = n;
  371. TensorShape A, B;
  372. if (param.transposeA) {
  373. std::swap(A0, A1);
  374. }
  375. if (param.transposeB) {
  376. std::swap(B0, B1);
  377. }
  378. used_algo = benchmark.execs({{A0, A1}, {B0, B1}, {}}) / RUNS;
  379. } else {
  380. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  381. if (param.transposeA) {
  382. std::swap(A0, A1);
  383. }
  384. if (param.transposeB) {
  385. std::swap(B0, B1);
  386. }
  387. used_algo = benchmark.execs({{A0, A1, pack_size, pack_size},
  388. {B0, B1, pack_size},
  389. {}}) /
  390. RUNS;
  391. }
  392. return used_algo;
  393. };
  394. size_t mk_size = MatrixMulForward::pack_size(format);
  395. size_t mk_size_contrast = MatrixMulForward::pack_size(contrast_format);
  396. size_t pack_size = std::max(mk_size, mk_size_contrast);
  397. for (auto& arg : args) {
  398. Param param;
  399. param.transposeA = arg.mask & 0x1;
  400. param.transposeB = arg.mask & 0x2;
  401. auto used_contrast = bench(benchmark_contrast, param, contrast_format,
  402. arg.m, arg.n, arg.k, pack_size);
  403. auto used_algo =
  404. bench(benchmark, param, format, arg.m, arg.n, arg.k, pack_size);
  405. float computations =
  406. 2.f * arg.m * pack_size * arg.k * pack_size * arg.n * 1e-6;
  407. printf("run: {(%zu, %zu) x (%zu, %zu)} contrast: %f ms %f Gflops %s: "
  408. "%f "
  409. "ms "
  410. "%f Gflops "
  411. "speedup: %f \n",
  412. arg.m * pack_size, arg.k * pack_size, arg.k * pack_size, arg.n,
  413. used_contrast, computations / used_contrast, algo, used_algo,
  414. computations / used_algo, used_contrast / used_algo);
  415. }
  416. }
  417. #endif
  418. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台