You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cutlass_matmul.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. /**
  2. * \file dnn/test/cuda/cutlass_matmul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include <cuda.h>
  13. #include "megdnn/oprs/linalg.h"
  14. #include "src/common/utils.h"
  15. #include "test/common/checker.h"
  16. #include "test/common/matrix_mul.h"
  17. #include "test/common/tensor.h"
  18. #include "test/common/workspace_wrapper.h"
  19. #include "test/cuda/benchmark.h"
  20. #include "test/cuda/fixture.h"
  21. #include "test/cuda/utils.h"
  22. #if CUDA_VERSION >= 9020
  23. namespace megdnn {
  24. namespace test {
  25. namespace {
  26. void test_multibatchsize(
  27. Handle* handle_cuda, DType A_dtype, DType B_dtype, DType C_dtype,
  28. const char* algo, const std::vector<matrix_mul::TestArg>& args,
  29. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
  30. const std::function<bool(const matrix_mul::TestArg&)>& filter = {}) {
  31. Checker<MatrixMulForward> checker(handle_cuda, false);
  32. if (algo) {
  33. checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(algo));
  34. }
  35. std::unique_ptr<RNG> rng;
  36. if (A_dtype.enumv() == DTypeEnum::Float32) {
  37. rng = std::make_unique<UniformFloatRNG>(-1, 1);
  38. megdnn_assert(
  39. B_dtype.enumv() == DTypeEnum::Float32 &&
  40. C_dtype.enumv() == DTypeEnum::Float32);
  41. }
  42. megdnn_assert(rng != nullptr);
  43. struct Compare {
  44. bool is_same(dt_float32 expected, dt_float32 actual) const {
  45. return expected == actual;
  46. }
  47. };
  48. // copy rhs->lhs, lhs is 8 times of rhs
  49. auto copy = [](SyncedTensor<dt_float32, Compare>& lhs,
  50. SyncedTensor<dt_float32, Compare>& rhs) {
  51. size_t chunk = rhs.layout().span().dist_byte();
  52. size_t tot = lhs.layout().span().dist_byte();
  53. megdnn_assert(tot % chunk == 0);
  54. char* pointer_lhs = reinterpret_cast<char*>(lhs.ptr_mutable_host());
  55. const char* pointer_rhs = reinterpret_cast<const char*>(rhs.ptr_host());
  56. for (size_t i = 0; i < tot; i += chunk) {
  57. std::memcpy(pointer_lhs + i, pointer_rhs, chunk);
  58. }
  59. };
  60. using Param = param::MatrixMul;
  61. megdnn_assert(format == Param::Format::DEFAULT);
  62. for (auto&& arg : args) {
  63. megdnn_assert(arg.mask == 0x0);
  64. // make m, n, k big enough
  65. size_t m = arg.m, n = (arg.n << 3), k = (arg.k << 3);
  66. size_t m_prime = (m << 3);
  67. if (filter && filter(arg))
  68. continue;
  69. TensorShape A{m, k}, B{k, n}, C{m, n};
  70. TensorShape A_prime{m_prime, k}, C_prime{m_prime, n};
  71. SyncedTensor<dt_float32, Compare> A_tensor{handle_cuda, {A, A_dtype}},
  72. B_tensor{handle_cuda, {B, B_dtype}},
  73. C_tensor{handle_cuda, {C, C_dtype}},
  74. A_tensor_prime{handle_cuda, {A_prime, A_dtype}},
  75. C_tensor_prime{handle_cuda, {C_prime, C_dtype}},
  76. C_tensor_batch{handle_cuda, {C_prime, C_dtype}};
  77. rng->gen(A_tensor.tensornd_host());
  78. rng->gen(B_tensor.tensornd_host());
  79. copy(A_tensor_prime, A_tensor);
  80. auto opr_reference = handle_cuda->create_operator<MatrixMulForward>();
  81. {
  82. opr_reference->execution_policy().algo.reset();
  83. for (auto i : opr_reference->get_all_algorithms_info_safe(
  84. A_tensor.layout(), B_tensor.layout(), C_tensor.layout())) {
  85. if (std::regex_match(
  86. i.desc.name.c_str(),
  87. std::regex("(" + std::string(algo) + ")(.*)"))) {
  88. opr_reference->execution_policy().algo = i.desc;
  89. break;
  90. }
  91. }
  92. megdnn_assert(opr_reference->execution_policy().algo.valid());
  93. size_t ws_size = opr_reference->get_workspace_in_bytes(
  94. A_tensor.layout(), B_tensor.layout(), C_tensor.layout());
  95. WorkspaceWrapper ws_reference(handle_cuda, ws_size);
  96. opr_reference->exec(
  97. A_tensor.tensornd_dev(), B_tensor.tensornd_dev(),
  98. C_tensor.tensornd_dev(), ws_reference.workspace());
  99. }
  100. copy(C_tensor_prime, C_tensor);
  101. checker.set_dtype(0, A_dtype)
  102. .set_dtype(1, B_dtype)
  103. .set_dtype(2, C_dtype)
  104. .set_epsilon(1e-6)
  105. .exect({A_tensor_prime.tensornd_host(), B_tensor.tensornd_host(), {}},
  106. {{}, {}, C_tensor_prime.tensornd_host()});
  107. {
  108. opr_reference->execution_policy().algo.reset();
  109. for (auto i : opr_reference->get_all_algorithms_info_safe(
  110. A_tensor_prime.layout(), B_tensor.layout(),
  111. C_tensor_batch.layout())) {
  112. if (std::regex_match(
  113. i.desc.name.c_str(),
  114. std::regex("(" + std::string(algo) + ")(.*)"))) {
  115. opr_reference->execution_policy().algo = i.desc;
  116. break;
  117. }
  118. }
  119. megdnn_assert(opr_reference->execution_policy().algo.valid());
  120. size_t ws_size = opr_reference->get_workspace_in_bytes(
  121. A_tensor_prime.layout(), B_tensor.layout(),
  122. C_tensor_batch.layout());
  123. WorkspaceWrapper ws_reference(handle_cuda, ws_size);
  124. opr_reference->exec(
  125. A_tensor_prime.tensornd_dev(), B_tensor.tensornd_dev(),
  126. C_tensor_batch.tensornd_dev(), ws_reference.workspace());
  127. }
  128. C_tensor_batch.check_with(C_tensor_prime);
  129. }
  130. }
  131. #if MEGDNN_WITH_BENCHMARK
  // One benchmark problem: matrix dimensions m, n, k plus a transpose mask
  // (bit 0x1 and bit 0x2 are read by benchmark_matrix_mul as transposeA and
  // transposeB). The mask defaults to 0.
  struct BenchArgs {
      size_t m;
      size_t n;
      size_t k;
      size_t mask = 0x0;
  };
  135. std::vector<BenchArgs> get_square_matmul_args() {
  136. std::vector<BenchArgs> args;
  137. args.emplace_back(BenchArgs{128, 128, 128});
  138. args.emplace_back(BenchArgs{256, 256, 256});
  139. args.emplace_back(BenchArgs{512, 512, 512});
  140. args.emplace_back(BenchArgs{1024, 1024, 1024});
  141. args.emplace_back(BenchArgs{2048, 2048, 2048});
  142. args.emplace_back(BenchArgs{4096, 4096, 4096});
  143. return args;
  144. }
  145. std::vector<BenchArgs> get_feat_model_args() {
  146. std::vector<BenchArgs> args;
  147. args.emplace_back(BenchArgs{2, 4096, 4096});
  148. args.emplace_back(BenchArgs{2, 1024, 6912});
  149. args.emplace_back(BenchArgs{2, 3456, 3456});
  150. args.emplace_back(BenchArgs{2, 2304, 2304});
  151. args.emplace_back(BenchArgs{1, 256, 8192});
  152. args.emplace_back(BenchArgs{2, 864, 864});
  153. args.emplace_back(BenchArgs{2, 9, 64});
  154. args.emplace_back(BenchArgs{4, 4096, 4096});
  155. args.emplace_back(BenchArgs{4, 1024, 6912});
  156. args.emplace_back(BenchArgs{4, 3456, 3456});
  157. args.emplace_back(BenchArgs{4, 2304, 2304});
  158. args.emplace_back(BenchArgs{2, 256, 8192});
  159. args.emplace_back(BenchArgs{4, 864, 864});
  160. args.emplace_back(BenchArgs{4, 9, 64});
  161. args.emplace_back(BenchArgs{8, 4096, 4096});
  162. args.emplace_back(BenchArgs{8, 1024, 6912});
  163. args.emplace_back(BenchArgs{8, 3456, 3456});
  164. args.emplace_back(BenchArgs{8, 2304, 2304});
  165. args.emplace_back(BenchArgs{4, 256, 8192});
  166. args.emplace_back(BenchArgs{8, 864, 864});
  167. args.emplace_back(BenchArgs{4, 9, 64});
  168. args.emplace_back(BenchArgs{16, 4096, 4096});
  169. args.emplace_back(BenchArgs{16, 1024, 6912});
  170. args.emplace_back(BenchArgs{16, 3456, 3456});
  171. args.emplace_back(BenchArgs{16, 2304, 2304});
  172. args.emplace_back(BenchArgs{8, 256, 8192});
  173. args.emplace_back(BenchArgs{16, 864, 864});
  174. args.emplace_back(BenchArgs{8, 9, 64});
  175. args.emplace_back(BenchArgs{32, 4096, 4096});
  176. args.emplace_back(BenchArgs{32, 1024, 6912});
  177. args.emplace_back(BenchArgs{32, 3456, 3456});
  178. args.emplace_back(BenchArgs{32, 2304, 2304});
  179. args.emplace_back(BenchArgs{16, 256, 8192});
  180. args.emplace_back(BenchArgs{32, 864, 864});
  181. args.emplace_back(BenchArgs{32, 9, 64});
  182. args.emplace_back(BenchArgs{64, 4096, 4096});
  183. args.emplace_back(BenchArgs{64, 1024, 6912});
  184. args.emplace_back(BenchArgs{64, 3456, 3456});
  185. args.emplace_back(BenchArgs{64, 2304, 2304});
  186. args.emplace_back(BenchArgs{32, 256, 8192});
  187. args.emplace_back(BenchArgs{64, 864, 864});
  188. args.emplace_back(BenchArgs{64, 9, 64});
  189. args.emplace_back(BenchArgs{128, 4096, 4096});
  190. args.emplace_back(BenchArgs{128, 1024, 6912});
  191. args.emplace_back(BenchArgs{128, 3456, 3456});
  192. args.emplace_back(BenchArgs{128, 2304, 2304});
  193. args.emplace_back(BenchArgs{64, 256, 8192});
  194. args.emplace_back(BenchArgs{128, 864, 864});
  195. args.emplace_back(BenchArgs{128, 9, 64});
  196. return args;
  197. }
  198. #if CUDA_VERSION >= 10010
  199. std::vector<BenchArgs> get_f16_feat_model_args() {
  200. std::vector<BenchArgs> args;
  201. args.emplace_back(BenchArgs{128, 9216, 9216});
  202. args.emplace_back(BenchArgs{128, 6400, 6400});
  203. args.emplace_back(BenchArgs{128, 5184, 5184});
  204. return args;
  205. }
  206. #endif
  207. void benchmark_matrix_mul(
  208. Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
  209. DType B_dtype, DType C_dtype, const char* algo = nullptr,
  210. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT) {
  211. megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
  212. CUBenchmarker<MatrixMulForward> benchmarker(handle);
  213. CUBenchmarker<MatrixMulForward> benchmarker_cublas(handle);
  214. size_t RUNS = 1000;
  215. benchmarker.set_display(false).set_times(RUNS);
  216. benchmarker_cublas.set_display(false).set_times(RUNS);
  217. benchmarker_cublas.set_before_exec_callback(
  218. AlgoChecker<MatrixMulForward>("CUBLAS"));
  219. benchmarker.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
  220. benchmarker_cublas.set_dtype(0, A_dtype)
  221. .set_dtype(1, B_dtype)
  222. .set_dtype(2, C_dtype);
  223. using Param = MatrixMul::Param;
  224. for (auto&& arg : args) {
  225. size_t m = arg.m, n = arg.n, k = arg.k;
  226. Param param;
  227. param.transposeA = arg.mask & 0x1;
  228. param.transposeB = arg.mask & 0x2;
  229. param.format = format;
  230. size_t A0 = m, A1 = k, B0 = k, B1 = n;
  231. if (param.transposeA) {
  232. std::swap(A0, A1);
  233. }
  234. if (param.transposeB) {
  235. std::swap(B0, B1);
  236. }
  237. benchmarker.set_param(param);
  238. TensorShape A{A0, A1}, B{B0, B1}, C{m, n};
  239. float time_in_ms = 0.f;
  240. if (algo) {
  241. time_in_ms = algo_benchmark<
  242. MatrixMulForward, OprProxy<MatrixMulForward>, CUTimer>(
  243. benchmarker, {A, B, C}, algo) /
  244. RUNS;
  245. } else {
  246. time_in_ms = benchmarker.execs({A, B, C}) / RUNS;
  247. }
  248. benchmarker_cublas.set_param(param);
  249. auto time_in_ms_cublas = benchmarker_cublas.execs({A, B, C}) / RUNS;
  250. float flo = 2.0 * m * n * k / (1e12);
  251. printf("A=%s, B=%s, C=%s, time(algo=%s)=%.2f %.2fTops, "
  252. "time(cublas)=%.2f %.2fTops, "
  253. "perf(algo=%s)/perf(cublas)=%.2f\n",
  254. A.to_string().c_str(), B.to_string().c_str(), C.to_string().c_str(),
  255. algo, time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cublas,
  256. (flo / (time_in_ms_cublas * 1e-3)), algo,
  257. time_in_ms_cublas / time_in_ms);
  258. }
  259. }
  260. #endif
  261. } // namespace
  262. TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
  263. auto args = matrix_mul::get_matmul_args_no_mask();
  264. test_multibatchsize(
  265. handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  266. "CUTLASS_FLOAT32_SIMT_128X128X8_32X64X8", args,
  267. param::MatrixMul::Format::DEFAULT);
  268. }
  269. TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
  270. auto args = matrix_mul::get_matmul_args_no_mask();
  271. test_multibatchsize(
  272. handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  273. "CUTLASS_FLOAT32_SIMT_SPLIT_K_128X128X8_32X64X8", args,
  274. param::MatrixMul::Format::DEFAULT,
  275. [](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; });
  276. }
  277. TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
  278. auto args = matrix_mul::get_matmul_args_no_mask();
  279. test_multibatchsize(
  280. handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  281. "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", args,
  282. param::MatrixMul::Format::DEFAULT);
  283. }
  284. TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
  285. auto args = matrix_mul::get_matmul_args_no_mask();
  286. test_multibatchsize(
  287. handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  288. "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", args,
  289. param::MatrixMul::Format::DEFAULT);
  290. }
  291. TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
  292. auto args = matrix_mul::get_matmul_args_no_mask();
  293. test_multibatchsize(
  294. handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
  295. "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", args,
  296. param::MatrixMul::Format::DEFAULT);
  297. }
  298. #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
  299. cb(1, 64, 256, 8, 32, 64, 8); \
  300. cb(2, 256, 64, 8, 64, 32, 8); \
  301. cb(3, 32, 256, 8, 16, 64, 8); \
  302. cb(4, 256, 32, 8, 64, 16, 8); \
  303. cb(5, 128, 128, 8, 32, 64, 8); \
  304. cb(6, 128, 64, 8, 64, 32, 8); \
  305. cb(7, 64, 128, 8, 32, 64, 8); \
  306. cb(8, 128, 32, 8, 64, 32, 8); \
  307. cb(9, 32, 128, 8, 32, 64, 8); \
  308. cb(10, 64, 64, 8, 32, 64, 8); \
  309. cb(11, 32, 64, 8, 32, 64, 8); \
  310. cb(12, 64, 32, 8, 64, 32, 8); \
  311. cb(13, 32, 32, 8, 32, 32, 8); \
  312. cb(14, 8, 32, 8, 8, 32, 8); \
  313. cb(15, 16, 32, 8, 16, 32, 8); \
  314. cb(16, 16, 64, 8, 16, 64, 8); \
  315. cb(17, 16, 128, 8, 16, 64, 8);
  316. #define cb(name, tbm, tbn, tbk, wm, wn, wk) \
  317. TEST_F(CUDA, CUTLASS_GEMM_##name) { \
  318. matrix_mul::check_matrix_mul<MatrixMulForward>( \
  319. dtype::Float32(), dtype::Float32(), dtype::Float32(), handle_cuda(), \
  320. "CUTLASS_FLOAT32_SIMT_" #tbm "X" #tbn "X" #tbk "_" #wm "X" #wn \
  321. "X" #wk); \
  322. }
  323. MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  324. #undef cb
  325. #define cb(name, tbm, tbn, tbk, wm, wn, wk) \
  326. TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_##name) { \
  327. matrix_mul::check_matrix_mul<MatrixMulForward>( \
  328. dtype::Float32(), dtype::Float32(), dtype::Float32(), handle_cuda(), \
  329. "CUTLASS_FLOAT32_SIMT_SPLIT_K_" #tbm "X" #tbn "X" #tbk "_" #wm "X" #wn \
  330. "X" #wk, \
  331. param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
  332. matrix_mul::get_matmul_args_split_k()); \
  333. }
  334. MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  335. #undef cb
  336. #undef MEGDNN_FOREACH_CUTLASS_KERNEL
  #if CUDA_VERSION >= 10010
  // Float16 tensor-op kernel configurations with an "h884" instruction tag.
  // Each entry is (test index, six tile sizes, three instruction sizes); all
  // values are only stringified into the algorithm name.
  #define MEGDNN_FOREACH_CUTLASS_KERNEL(emit)        \
      emit(1, 256, 128, 32, 64, 64, 32, 8, 8, 4);    \
      emit(2, 128, 256, 32, 64, 64, 32, 8, 8, 4);    \
      emit(3, 128, 128, 32, 64, 64, 32, 8, 8, 4);

  // CUTLASS_F16_884_GEMM_<index>: requires compute capability 7.0, runs on the
  // general argument set with an epsilon of 1e-2.
  // NOTE(review): unlike the split-K variant below, no compute mode is passed
  // here -- confirm that this difference is intended.
  #define cb(idx, tb_m, tb_n, tb_k, warp_m, warp_n, warp_k, inst_m, inst_n, inst_k) \
      TEST_F(CUDA, CUTLASS_F16_884_GEMM_##idx) {                                    \
          require_compute_capability(7, 0);                                         \
          matrix_mul::check_matrix_mul<MatrixMulForward>(                           \
                  dtype::Float16(), dtype::Float16(), dtype::Float16(),             \
                  handle_cuda(),                                                    \
                  "CUTLASS_FLOAT16_TENSOR_OP_h" #inst_m #inst_n #inst_k             \
                  "_" #tb_m "X" #tb_n "X" #tb_k                                     \
                  "_" #warp_m "X" #warp_n "X" #warp_k,                              \
                  param::MatrixMul::Format::DEFAULT, 8, 1e-2,                       \
                  matrix_mul::get_matmul_args());                                   \
      }
  MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  #undef cb

  // CUTLASS_F16_884_GEMM_SPLIT_K_<index>: requires compute capability 7.0,
  // runs on the split-K argument set with an epsilon of 1e-3 and the FLOAT32
  // compute mode.
  #define cb(idx, tb_m, tb_n, tb_k, warp_m, warp_n, warp_k, inst_m, inst_n, inst_k) \
      TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##idx) {                            \
          require_compute_capability(7, 0);                                         \
          matrix_mul::check_matrix_mul<MatrixMulForward>(                           \
                  dtype::Float16(), dtype::Float16(), dtype::Float16(),             \
                  handle_cuda(),                                                    \
                  "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #inst_m #inst_n #inst_k     \
                  "_" #tb_m "X" #tb_n "X" #tb_k                                     \
                  "_" #warp_m "X" #warp_n "X" #warp_k,                              \
                  param::MatrixMul::Format::DEFAULT, 8, 1e-3,                       \
                  matrix_mul::get_matmul_args_split_k(), true,                      \
                  param::MatrixMul::ComputeMode::FLOAT32);                          \
      }
  MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  #undef cb
  #undef MEGDNN_FOREACH_CUTLASS_KERNEL
  #endif
  #if CUDA_VERSION >= 10020
  // Float16 tensor-op kernel configurations with an "h1688" instruction tag.
  // Each entry is (test index, six tile sizes, three instruction sizes); all
  // values are only stringified into the algorithm name.
  #define MEGDNN_FOREACH_CUTLASS_KERNEL(emit)         \
      emit(1, 256, 128, 32, 64, 64, 32, 16, 8, 8);    \
      emit(2, 128, 256, 32, 64, 64, 32, 16, 8, 8);    \
      emit(3, 128, 128, 32, 64, 64, 32, 16, 8, 8);

  // CUTLASS_F16_1688_GEMM_<index>: requires compute capability 7.5, runs on
  // the general argument set with an epsilon of 1e-2 and the FLOAT32 compute
  // mode.
  #define cb(idx, tb_m, tb_n, tb_k, warp_m, warp_n, warp_k, inst_m, inst_n, inst_k) \
      TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##idx) {                                   \
          require_compute_capability(7, 5);                                         \
          matrix_mul::check_matrix_mul<MatrixMulForward>(                           \
                  dtype::Float16(), dtype::Float16(), dtype::Float16(),             \
                  handle_cuda(),                                                    \
                  "CUTLASS_FLOAT16_TENSOR_OP_h" #inst_m #inst_n #inst_k             \
                  "_" #tb_m "X" #tb_n "X" #tb_k                                     \
                  "_" #warp_m "X" #warp_n "X" #warp_k,                              \
                  param::MatrixMul::Format::DEFAULT, 8, 1e-2,                       \
                  matrix_mul::get_matmul_args(), true,                              \
                  param::MatrixMul::ComputeMode::FLOAT32);                          \
      }
  MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  #undef cb

  // CUTLASS_F16_1688_GEMM_SPLIT_K_<index>: requires compute capability 7.5,
  // runs on the split-K argument set with an epsilon of 1e-3.
  // NOTE(review): unlike the non-split-K variant above, no compute mode is
  // passed here -- confirm that this difference is intended.
  #define cb(idx, tb_m, tb_n, tb_k, warp_m, warp_n, warp_k, inst_m, inst_n, inst_k) \
      TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##idx) {                           \
          require_compute_capability(7, 5);                                         \
          matrix_mul::check_matrix_mul<MatrixMulForward>(                           \
                  dtype::Float16(), dtype::Float16(), dtype::Float16(),             \
                  handle_cuda(),                                                    \
                  "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #inst_m #inst_n #inst_k     \
                  "_" #tb_m "X" #tb_n "X" #tb_k                                     \
                  "_" #warp_m "X" #warp_n "X" #warp_k,                              \
                  param::MatrixMul::Format::DEFAULT, 8, 1e-3,                       \
                  matrix_mul::get_matmul_args_split_k());                           \
      }
  MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
  #undef cb
  #undef MEGDNN_FOREACH_CUTLASS_KERNEL
  #endif
  401. #if MEGDNN_WITH_BENCHMARK
  402. TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
  403. benchmark_matrix_mul(
  404. handle_cuda(), get_square_matmul_args(), dtype::Float32(), dtype::Float32(),
  405. dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
  406. }
  407. TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
  408. benchmark_matrix_mul(
  409. handle_cuda(), get_feat_model_args(), dtype::Float32(), dtype::Float32(),
  410. dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
  411. }
  412. #if CUDA_VERSION >= 10010
  413. TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
  414. benchmark_matrix_mul(
  415. handle_cuda(), get_f16_feat_model_args(), dtype::Float16(),
  416. dtype::Float16(), dtype::Float16(), "CUTLASS_FLOAT16_TENSOR_OP");
  417. }
  418. #endif
  419. #endif
  420. } // namespace test
  421. } // namespace megdnn
  422. #endif
  423. // vim: syntax=cpp.doxygen