You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
#include "test/common/matrix_mul.h"

#include <algorithm>
#include <utility>

#include "src/common/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. constexpr size_t matrix_mul::TestArg::UNSET_STRIDE_VAL;
  8. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() {
  9. std::vector<TestArg> args;
  10. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 32})
  11. for (size_t n : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
  12. 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32})
  13. for (size_t k : {1, 2, 4, 8, 11, 12, 15, 16, 31, 32, 37})
  14. args.emplace_back(m, n, k, 0);
  15. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17})
  16. args.emplace_back(m, m + 1, m + 2, 0);
  17. for (size_t mbase : {11})
  18. for (size_t test_case_offset : {64, 256, 512}) {
  19. size_t mnk = mbase + test_case_offset;
  20. args.emplace_back(mnk, mnk, mnk, 0);
  21. return args;
  22. }
  23. return args;
  24. }
  25. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(size_t nbase) {
  26. std::vector<TestArg> args;
  27. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11})
  28. for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
  29. for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11})
  30. args.emplace_back(m, n * nbase, k, 0);
  31. return args;
  32. }
  33. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_cublaslt() {
  34. std::vector<TestArg> args;
  35. for (size_t m : {4, 6, 8, 16}) {
  36. for (size_t n : {4, 6, 8, 16}) {
  37. //[TODO]: the following test case are disabled due to the
  38. // cublasLt(version: 10020) produce wrong result when k in [65, 97],
  39. // so please uncomment it if the bug is fixed
  40. for (size_t k : {32, 64}) {
  41. args.emplace_back(
  42. m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
  43. TestArg::UNSET_STRIDE_VAL, TestArg::UNSET_STRIDE_VAL, 2);
  44. }
  45. }
  46. }
  47. return args;
  48. }
  49. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_int8x8x32() {
  50. std::vector<TestArg> args;
  51. for (size_t m : {1, 2, 3, 4, 5, 8, 64}) {
  52. for (size_t n : {1, 2, 3, 4, 5, 8, 64}) {
  53. for (size_t k : {1, 2, 3, 4, 5, 8, 64}) {
  54. args.emplace_back(
  55. m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
  56. TestArg::UNSET_STRIDE_VAL, TestArg::UNSET_STRIDE_VAL, 2);
  57. }
  58. }
  59. }
  60. return args;
  61. }
  62. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(uint8_t mask) {
  63. std::vector<TestArg> args;
  64. std::vector<TestArg> args_temp = matrix_mul::get_matmul_args_no_mask();
  65. for (auto arg : args_temp) {
  66. arg.mask = mask;
  67. args.emplace_back(arg);
  68. }
  69. // non-contiguous case
  70. for (size_t m : {110})
  71. for (size_t n : {119})
  72. for (size_t k : {120}) {
  73. // A: (m, k)
  74. size_t Astride = mask & 1 ? m + 2 : k + 2;
  75. // B: (k, n)
  76. size_t Bstride = mask & 2 ? k + 2 : n + 2;
  77. size_t Cstride = n * 2 + 2;
  78. args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
  79. }
  80. return args;
  81. }
  82. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args() {
  83. std::vector<TestArg> args;
  84. for (size_t mask = 0; mask < 4; ++mask) {
  85. std::vector<TestArg> args_temp = matrix_mul::get_matmul_args_mask(mask);
  86. for (auto arg : args_temp)
  87. args.emplace_back(arg);
  88. }
  89. return args;
  90. }
  91. std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_split_k() {
  92. std::vector<TestArg> args = get_matmul_args();
  93. for (auto iter = args.begin(); iter < args.end();) {
  94. if (iter->k <= iter->n) {
  95. iter = args.erase(iter);
  96. } else {
  97. iter++;
  98. }
  99. }
  100. return args;
  101. }
  102. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_mask(
  103. uint8_t mask) {
  104. std::vector<TestArg> args;
  105. for (size_t b : {1, 2, 3}) {
  106. std::vector<TestArg> args_temp =
  107. megdnn::test::matrix_mul::get_matmul_args_mask(mask);
  108. for (auto arg : args_temp) {
  109. arg.b = b;
  110. args.emplace_back(arg);
  111. }
  112. }
  113. return args;
  114. }
  115. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() {
  116. std::vector<TestArg> args;
  117. for (size_t mask = 0; mask < 4; ++mask) {
  118. std::vector<TestArg> args_temp = matrix_mul::get_batched_matmul_args_mask(mask);
  119. for (auto arg : args_temp)
  120. args.emplace_back(arg);
  121. }
  122. return args;
  123. }
  124. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_broadcast_args() {
  125. std::vector<TestArg> args;
  126. for (size_t mask = 0; mask < 4; ++mask) {
  127. std::vector<TestArg> args_temp =
  128. matrix_mul::get_batched_matmul_broadcast_args_mask(mask);
  129. for (auto arg : args_temp)
  130. args.emplace_back(arg);
  131. }
  132. return args;
  133. }
  134. std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_broadcast_args_mask(
  135. uint8_t mask) {
  136. std::vector<TestArg> args;
  137. std::vector<TestArg> args_temp = matrix_mul::get_batched_matmul_args_mask(mask);
  138. for (auto arg : args_temp) {
  139. args.emplace_back(arg);
  140. args.back().A_batch_stride = 0;
  141. }
  142. return args;
  143. }
  144. template <typename Opr>
  145. void matrix_mul::check_matrix_mul(
  146. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  147. const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format,
  148. size_t nbase, float eps, std::vector<TestArg>&& user_args,
  149. bool force_deduce_dst, param::MatrixMul::ComputeMode compute_mode) {
  150. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  151. Checker<Opr> checker(handle);
  152. checker.set_force_deduce_dst(force_deduce_dst);
  153. if (!algo.name.empty()) {
  154. checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
  155. }
  156. std::unique_ptr<RNG> rng;
  157. checker.set_epsilon(eps);
  158. if (A_dtype.enumv() == DTypeEnum::Int8 ||
  159. A_dtype.enumv() == DTypeEnum::QuantizedS8) {
  160. //! use larger rng to check the overflow
  161. rng = std::make_unique<UniformIntRNG>(-127, 127);
  162. } else if (
  163. A_dtype.enumv() == DTypeEnum::Uint8 ||
  164. A_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  165. rng = std::make_unique<NormalRNG>(128.f);
  166. } else if (A_dtype.enumv() == DTypeEnum::Int16) {
  167. rng = std::make_unique<UniformIntRNG>(-32767, 32767);
  168. } else if (A_dtype.enumv() == DTypeEnum::Float16) {
  169. rng = std::make_unique<NormalRNG>(2.f);
  170. //! if fp16 not set eps, default 1e-3, we just set it to 1e-2
  171. if (eps < 1e-2) {
  172. checker.set_epsilon(1e-2);
  173. }
  174. }
  175. if (rng) {
  176. checker.set_rng(0, rng.get()).set_rng(1, rng.get());
  177. }
  178. //! return expect if stride == -1, stride otherwise
  179. auto stride_val = [](size_t stride, size_t expect) -> size_t {
  180. if (stride == TestArg::UNSET_STRIDE_VAL) {
  181. return expect;
  182. } else {
  183. return stride;
  184. }
  185. };
  186. constexpr static bool batched = std::is_same<Opr, megdnn::BatchedMatrixMul>::value;
  187. using Param = MatrixMul::Param;
  188. std::vector<TestArg> args;
  189. if (user_args.empty()) {
  190. if (format == param::MatrixMul::Format::DEFAULT) {
  191. if (batched) {
  192. args = matrix_mul::get_batched_matmul_args();
  193. } else {
  194. args = matrix_mul::get_matmul_args();
  195. }
  196. } else {
  197. megdnn_assert(!batched, "BatchedMatrixMul does not support MK4/MK8");
  198. args = matrix_mul::get_matmul_mk_packed_args(nbase);
  199. }
  200. } else {
  201. args = user_args;
  202. }
  203. size_t pack_size = MatrixMulForward::pack_size(format);
  204. for (auto& arg : args) {
  205. size_t m = arg.m, n = arg.n, k = arg.k;
  206. if (handle->type() == Handle::HandleType::CUDA) {
  207. //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
  208. bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
  209. A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
  210. A_dtype.enumv() == DTypeEnum::Uint8 ||
  211. A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
  212. if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
  213. continue;
  214. }
  215. }
  216. Param param;
  217. param.transposeA = arg.mask & 0x1;
  218. param.transposeB = arg.mask & 0x2;
  219. param.compute_mode = compute_mode;
  220. param.format = format;
  221. checker.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
  222. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  223. TensorShape A, B;
  224. if (param.transposeA) {
  225. std::swap(A0, A1);
  226. }
  227. if (param.transposeB) {
  228. std::swap(B0, B1);
  229. }
  230. ptrdiff_t A_stride = arg.A_stride, B_stride = arg.B_stride,
  231. C_stride = arg.C_stride, A_batch_stride = arg.A_batch_stride,
  232. B_batch_stride = arg.B_batch_stride,
  233. C_batch_stride = arg.C_batch_stride;
  234. A_stride = stride_val(A_stride, A1);
  235. B_stride = stride_val(B_stride, B1);
  236. C_stride = stride_val(C_stride, n);
  237. A_batch_stride = stride_val(A_batch_stride, A0 * A_stride);
  238. B_batch_stride = stride_val(B_batch_stride, B0 * B_stride);
  239. C_batch_stride = stride_val(C_batch_stride, m * C_stride);
  240. checker.set_param(param);
  241. if (format == param::MatrixMul::Format::DEFAULT) {
  242. if (batched) {
  243. checker.execl(
  244. {TensorLayout{
  245. {arg.b, A0, A1},
  246. {A_batch_stride, A_stride, 1},
  247. A_dtype},
  248. TensorLayout{
  249. {arg.b, B0, B1},
  250. {B_batch_stride, B_stride, 1},
  251. B_dtype},
  252. TensorLayout{
  253. {arg.b, m, n},
  254. {C_batch_stride, C_stride, 1},
  255. C_dtype}});
  256. } else {
  257. checker.execl(
  258. {TensorLayout{{A0, A1}, {A_stride, 1}, A_dtype},
  259. TensorLayout{{B0, B1}, {B_stride, 1}, B_dtype},
  260. TensorLayout{{m, n}, {C_stride, 1}, C_dtype}});
  261. }
  262. } else {
  263. //! ignore non-contiguous, only DEFAULT format support
  264. //! non-contiguous input
  265. checker.execs({{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}});
  266. }
  267. }
  268. }
  269. void matrix_mul::check_batched_matrix_mul(
  270. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  271. const ExecutionPolicyAlgoName& algo, float eps, std::vector<TestArg>&& args,
  272. bool force_deduce_dst) {
  273. check_matrix_mul<megdnn::BatchedMatrixMul>(
  274. A_dtype, B_dtype, C_dtype, handle, algo, param::MatrixMul::Format::DEFAULT,
  275. 8, eps, std::forward<decltype(args)>(args), force_deduce_dst);
  276. }
  277. void matrix_mul::check_matrix_mul(
  278. DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
  279. const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format,
  280. size_t nbase, float eps, bool force_deduce_dst) {
  281. check_matrix_mul<megdnn::MatrixMul>(
  282. A_dtype, B_dtype, C_dtype, handle, algo, format, nbase, eps, {},
  283. force_deduce_dst);
  284. }
  285. #if MEGDNN_WITH_BENCHMARK
  286. std::vector<matrix_mul::TestArg> matrix_mul::get_benchmark_matmul_args() {
  287. std::vector<matrix_mul::TestArg> args;
  288. args.emplace_back(256, 12 * 24, 256, 0);
  289. //////////////////////// gemv //////////////////////////
  290. for (size_t M : {8, 64, 112, 256}) {
  291. for (size_t K : {8, 64, 112, 256}) {
  292. args.emplace_back(M, 1, K, 0);
  293. }
  294. }
  295. //////////////////////// gemm //////////////////////////
  296. for (size_t M : {8, 64, 112, 256}) {
  297. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  298. for (size_t N : {8, 64, 112, 256}) {
  299. args.emplace_back(M, N, K, 0);
  300. }
  301. }
  302. }
  303. return args;
  304. }
  305. std::vector<matrix_mul::TestArg> matrix_mul::get_benchmark_matmul_mk_packed_args(
  306. size_t nbase) {
  307. std::vector<TestArg> args;
  308. for (size_t m : {2, 4, 8, 16, 24, 32, 64})
  309. for (size_t n : {1, 2, 3, 4, 8, 16, 32, 64})
  310. for (size_t k : {2, 4, 8, 16, 24, 32, 64})
  311. args.emplace_back(m, n * nbase, k, 0);
  312. return args;
  313. }
  314. void matrix_mul::benchmark_with_contrast(
  315. Handle* handle, const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  316. DType C_dtype, const char* algo, param::MatrixMul::Format format,
  317. DType contrast_A_dtype, DType contrast_B_dtype, DType contrast_C_dtype,
  318. const char* contrast_algo, param::MatrixMul::Format contrast_format) {
  319. using Param = MatrixMul::Param;
  320. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  321. megdnn_assert(contrast_A_dtype.enumv() == contrast_B_dtype.enumv());
  322. Benchmarker<MatrixMul> benchmark_contrast(handle);
  323. Benchmarker<MatrixMul> benchmark(handle);
  324. constexpr size_t RUNS = 50;
  325. if (algo) {
  326. benchmark.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
  327. }
  328. if (contrast_algo) {
  329. benchmark_contrast.set_before_exec_callback(
  330. AlgoChecker<MatrixMul>(contrast_algo));
  331. }
  332. benchmark.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
  333. benchmark.set_times(RUNS);
  334. benchmark_contrast.set_dtype(0, contrast_A_dtype)
  335. .set_dtype(1, contrast_B_dtype)
  336. .set_dtype(2, contrast_C_dtype);
  337. benchmark_contrast.set_times(RUNS);
  338. auto bench = [](Benchmarker<MatrixMul>& benchmark, Param param,
  339. param::MatrixMul::Format format, size_t m, size_t n, size_t k,
  340. size_t pack_size) -> float {
  341. param.format = format;
  342. benchmark.set_param(param);
  343. float used_algo = 1.0;
  344. if (format == param::MatrixMul::Format::DEFAULT) {
  345. size_t A0 = m * pack_size, A1 = k * pack_size, B0 = k * pack_size, B1 = n;
  346. TensorShape A, B;
  347. if (param.transposeA) {
  348. std::swap(A0, A1);
  349. }
  350. if (param.transposeB) {
  351. std::swap(B0, B1);
  352. }
  353. used_algo = benchmark.execs({{A0, A1}, {B0, B1}, {}}) / RUNS;
  354. } else {
  355. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  356. if (param.transposeA) {
  357. std::swap(A0, A1);
  358. }
  359. if (param.transposeB) {
  360. std::swap(B0, B1);
  361. }
  362. used_algo =
  363. benchmark.execs(
  364. {{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}}) /
  365. RUNS;
  366. }
  367. return used_algo;
  368. };
  369. size_t mk_size = MatrixMulForward::pack_size(format);
  370. size_t mk_size_contrast = MatrixMulForward::pack_size(contrast_format);
  371. size_t pack_size = std::max(mk_size, mk_size_contrast);
  372. for (auto& arg : args) {
  373. Param param;
  374. param.transposeA = arg.mask & 0x1;
  375. param.transposeB = arg.mask & 0x2;
  376. auto used_contrast =
  377. bench(benchmark_contrast, param, contrast_format, arg.m, arg.n, arg.k,
  378. pack_size);
  379. auto used_algo =
  380. bench(benchmark, param, format, arg.m, arg.n, arg.k, pack_size);
  381. float computations = 2.f * arg.m * pack_size * arg.k * pack_size * arg.n * 1e-6;
  382. printf("run: {(%zu, %zu) x (%zu, %zu)} contrast: %f ms %f Gflops %s: "
  383. "%f "
  384. "ms "
  385. "%f Gflops "
  386. "speedup: %f \n",
  387. arg.m * pack_size, arg.k * pack_size, arg.k * pack_size, arg.n,
  388. used_contrast, computations / used_contrast, algo, used_algo,
  389. computations / used_algo, used_contrast / used_algo);
  390. }
  391. }
  392. void matrix_mul::benchmark_single_algo(
  393. Handle* handle, const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
  394. DType C_dtype, const char* algo, param::MatrixMul::Format format) {
  395. using Param = MatrixMul::Param;
  396. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  397. Benchmarker<MatrixMul> benchmark(handle);
  398. constexpr size_t RUNS = 50;
  399. if (algo) {
  400. benchmark.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
  401. }
  402. benchmark.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
  403. benchmark.set_times(RUNS);
  404. auto bench = [](Benchmarker<MatrixMul>& benchmark, Param param,
  405. param::MatrixMul::Format format, size_t m, size_t n, size_t k,
  406. size_t pack_size) -> float {
  407. param.format = format;
  408. benchmark.set_param(param);
  409. float used_algo = 1.0;
  410. if (format == param::MatrixMul::Format::DEFAULT) {
  411. size_t A0 = m * pack_size, A1 = k * pack_size, B0 = k * pack_size, B1 = n;
  412. TensorShape A, B;
  413. if (param.transposeA) {
  414. std::swap(A0, A1);
  415. }
  416. if (param.transposeB) {
  417. std::swap(B0, B1);
  418. }
  419. used_algo = benchmark.execs({{A0, A1}, {B0, B1}, {}}) / RUNS;
  420. } else {
  421. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  422. if (param.transposeA) {
  423. std::swap(A0, A1);
  424. }
  425. if (param.transposeB) {
  426. std::swap(B0, B1);
  427. }
  428. used_algo =
  429. benchmark.execs(
  430. {{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}}) /
  431. RUNS;
  432. }
  433. return used_algo;
  434. };
  435. size_t pack_size = MatrixMulForward::pack_size(format);
  436. for (auto& arg : args) {
  437. Param param;
  438. param.transposeA = arg.mask & 0x1;
  439. param.transposeB = arg.mask & 0x2;
  440. auto used_algo =
  441. bench(benchmark, param, format, arg.m, arg.n, arg.k, pack_size);
  442. float computations = 2.f * arg.m * pack_size * arg.k * pack_size * arg.n * 1e-6;
  443. printf("run: {(%zu, %zu) x (%zu, %zu)} %f ms %f Gflops\n", arg.m * pack_size,
  444. arg.k * pack_size, arg.k * pack_size, arg.n, used_algo,
  445. computations / used_algo);
  446. }
  447. }
  448. #endif
  449. // vim: syntax=cpp.doxygen