
cutlass_matmul.cpp

#include <cuda.h>
#include "megdnn/oprs/linalg.h"

#include "src/common/utils.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"

#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
namespace {
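
// Checks that `algo` is insensitive to batch size: the problem (m, k) x (k, n)
// is solved once in a reference run, its input A and output C are replicated
// 8 times along the M dimension, and the enlarged problem (8m, k) x (k, n)
// must reproduce the replicated result, both through the Checker and through a
// direct run of the matched algorithm.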
void test_multibatchsize(
        Handle* handle_cuda, DType A_dtype, DType B_dtype, DType C_dtype,
        const char* algo, const std::vector<matrix_mul::TestArg>& args,
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
        const std::function<bool(const matrix_mul::TestArg&)>& filter = {}) {
    Checker<MatrixMulForward> checker(handle_cuda, false);
    if (algo) {
        checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(algo));
    }
    std::unique_ptr<RNG> rng;
    if (A_dtype.enumv() == DTypeEnum::Float32) {
        rng = std::make_unique<UniformFloatRNG>(-1, 1);
        megdnn_assert(
                B_dtype.enumv() == DTypeEnum::Float32 &&
                C_dtype.enumv() == DTypeEnum::Float32);
    }
    megdnn_assert(rng != nullptr);
    struct Compare {
        bool is_same(dt_float32 expected, dt_float32 actual) const {
            return expected == actual;
        }
    };
    // copy rhs -> lhs; lhs holds 8 copies of rhs stacked along the first axis
    auto copy = [](SyncedTensor<dt_float32, Compare>& lhs,
                   SyncedTensor<dt_float32, Compare>& rhs) {
        size_t chunk = rhs.layout().span().dist_byte();
        size_t tot = lhs.layout().span().dist_byte();
        megdnn_assert(tot % chunk == 0);
        char* pointer_lhs = reinterpret_cast<char*>(lhs.ptr_mutable_host());
        const char* pointer_rhs = reinterpret_cast<const char*>(rhs.ptr_host());
        for (size_t i = 0; i < tot; i += chunk) {
            std::memcpy(pointer_lhs + i, pointer_rhs, chunk);
        }
    };
    using Param = param::MatrixMul;
    megdnn_assert(format == Param::Format::DEFAULT);
    for (auto&& arg : args) {
        megdnn_assert(arg.mask == 0x0);
        // make n and k big enough; m is enlarged separately via m_prime
        size_t m = arg.m, n = (arg.n << 3), k = (arg.k << 3);
        size_t m_prime = (m << 3);
        if (filter && filter(arg))
            continue;
        TensorShape A{m, k}, B{k, n}, C{m, n};
        TensorShape A_prime{m_prime, k}, C_prime{m_prime, n};
        SyncedTensor<dt_float32, Compare> A_tensor{handle_cuda, {A, A_dtype}},
                B_tensor{handle_cuda, {B, B_dtype}},
                C_tensor{handle_cuda, {C, C_dtype}},
                A_tensor_prime{handle_cuda, {A_prime, A_dtype}},
                C_tensor_prime{handle_cuda, {C_prime, C_dtype}},
                C_tensor_batch{handle_cuda, {C_prime, C_dtype}};
        rng->gen(A_tensor.tensornd_host());
        rng->gen(B_tensor.tensornd_host());
        copy(A_tensor_prime, A_tensor);
        auto opr_reference = handle_cuda->create_operator<MatrixMulForward>();
        {
            opr_reference->execution_policy().algo.reset();
            for (auto i : opr_reference->get_all_algorithms_info_safe(
                         A_tensor.layout(), B_tensor.layout(), C_tensor.layout())) {
                if (std::regex_match(
                            i.desc.name.c_str(),
                            std::regex("(" + std::string(algo) + ")(.*)"))) {
                    opr_reference->execution_policy().algo = i.desc;
                    break;
                }
            }
            megdnn_assert(opr_reference->execution_policy().algo.valid());
            size_t ws_size = opr_reference->get_workspace_in_bytes(
                    A_tensor.layout(), B_tensor.layout(), C_tensor.layout());
            WorkspaceWrapper ws_reference(handle_cuda, ws_size);
            opr_reference->exec(
                    A_tensor.tensornd_dev(), B_tensor.tensornd_dev(),
                    C_tensor.tensornd_dev(), ws_reference.workspace());
        }
        copy(C_tensor_prime, C_tensor);
        checker.set_dtype(0, A_dtype)
                .set_dtype(1, B_dtype)
                .set_dtype(2, C_dtype)
                .set_epsilon(1e-6)
                .exect({A_tensor_prime.tensornd_host(), B_tensor.tensornd_host(), {}},
                       {{}, {}, C_tensor_prime.tensornd_host()});
        {
            opr_reference->execution_policy().algo.reset();
            for (auto i : opr_reference->get_all_algorithms_info_safe(
                         A_tensor_prime.layout(), B_tensor.layout(),
                         C_tensor_batch.layout())) {
                if (std::regex_match(
                            i.desc.name.c_str(),
                            std::regex("(" + std::string(algo) + ")(.*)"))) {
                    opr_reference->execution_policy().algo = i.desc;
                    break;
                }
            }
            megdnn_assert(opr_reference->execution_policy().algo.valid());
            size_t ws_size = opr_reference->get_workspace_in_bytes(
                    A_tensor_prime.layout(), B_tensor.layout(),
                    C_tensor_batch.layout());
            WorkspaceWrapper ws_reference(handle_cuda, ws_size);
            opr_reference->exec(
                    A_tensor_prime.tensornd_dev(), B_tensor.tensornd_dev(),
                    C_tensor_batch.tensornd_dev(), ws_reference.workspace());
        }
        C_tensor_batch.check_with(C_tensor_prime);
    }
}

#if MEGDNN_WITH_BENCHMARK
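// A single benchmark case; `mask` bit 0 requests a transposed A and bit 1 a
// transposed B (see benchmark_matrix_mul below).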
struct BenchArgs {
    size_t m, n, k, mask = 0x0;
};

std::vector<BenchArgs> get_square_matmul_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{128, 128, 128});
    args.emplace_back(BenchArgs{256, 256, 256});
    args.emplace_back(BenchArgs{512, 512, 512});
    args.emplace_back(BenchArgs{1024, 1024, 1024});
    args.emplace_back(BenchArgs{2048, 2048, 2048});
    args.emplace_back(BenchArgs{4096, 4096, 4096});
    return args;
}

std::vector<BenchArgs> get_feat_model_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{2, 4096, 4096});
    args.emplace_back(BenchArgs{2, 1024, 6912});
    args.emplace_back(BenchArgs{2, 3456, 3456});
    args.emplace_back(BenchArgs{2, 2304, 2304});
    args.emplace_back(BenchArgs{1, 256, 8192});
    args.emplace_back(BenchArgs{2, 864, 864});
    args.emplace_back(BenchArgs{2, 9, 64});
    args.emplace_back(BenchArgs{4, 4096, 4096});
    args.emplace_back(BenchArgs{4, 1024, 6912});
    args.emplace_back(BenchArgs{4, 3456, 3456});
    args.emplace_back(BenchArgs{4, 2304, 2304});
    args.emplace_back(BenchArgs{2, 256, 8192});
    args.emplace_back(BenchArgs{4, 864, 864});
    args.emplace_back(BenchArgs{4, 9, 64});
    args.emplace_back(BenchArgs{8, 4096, 4096});
    args.emplace_back(BenchArgs{8, 1024, 6912});
    args.emplace_back(BenchArgs{8, 3456, 3456});
    args.emplace_back(BenchArgs{8, 2304, 2304});
    args.emplace_back(BenchArgs{4, 256, 8192});
    args.emplace_back(BenchArgs{8, 864, 864});
    args.emplace_back(BenchArgs{4, 9, 64});
    args.emplace_back(BenchArgs{16, 4096, 4096});
    args.emplace_back(BenchArgs{16, 1024, 6912});
    args.emplace_back(BenchArgs{16, 3456, 3456});
    args.emplace_back(BenchArgs{16, 2304, 2304});
    args.emplace_back(BenchArgs{8, 256, 8192});
    args.emplace_back(BenchArgs{16, 864, 864});
    args.emplace_back(BenchArgs{8, 9, 64});
    args.emplace_back(BenchArgs{32, 4096, 4096});
    args.emplace_back(BenchArgs{32, 1024, 6912});
    args.emplace_back(BenchArgs{32, 3456, 3456});
    args.emplace_back(BenchArgs{32, 2304, 2304});
    args.emplace_back(BenchArgs{16, 256, 8192});
    args.emplace_back(BenchArgs{32, 864, 864});
    args.emplace_back(BenchArgs{32, 9, 64});
    args.emplace_back(BenchArgs{64, 4096, 4096});
    args.emplace_back(BenchArgs{64, 1024, 6912});
    args.emplace_back(BenchArgs{64, 3456, 3456});
    args.emplace_back(BenchArgs{64, 2304, 2304});
    args.emplace_back(BenchArgs{32, 256, 8192});
    args.emplace_back(BenchArgs{64, 864, 864});
    args.emplace_back(BenchArgs{64, 9, 64});
    args.emplace_back(BenchArgs{128, 4096, 4096});
    args.emplace_back(BenchArgs{128, 1024, 6912});
    args.emplace_back(BenchArgs{128, 3456, 3456});
    args.emplace_back(BenchArgs{128, 2304, 2304});
    args.emplace_back(BenchArgs{64, 256, 8192});
    args.emplace_back(BenchArgs{128, 864, 864});
    args.emplace_back(BenchArgs{128, 9, 64});
    return args;
}

#if CUDA_VERSION >= 10010
std::vector<BenchArgs> get_f16_feat_model_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{128, 9216, 9216});
    args.emplace_back(BenchArgs{128, 6400, 6400});
    args.emplace_back(BenchArgs{128, 5184, 5184});
    return args;
}
#endif
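
// Benchmarks every (m, n, k) case with the given algorithm (or the heuristic
// choice when `algo` is null) and with cuBLAS, averaging over 1000 runs, then
// prints the per-run times, the achieved throughput in TOps, and the speedup
// over cuBLAS.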
void benchmark_matrix_mul(
        Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
        DType B_dtype, DType C_dtype, const char* algo = nullptr,
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT) {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
    CUBenchmarker<MatrixMulForward> benchmarker(handle);
    CUBenchmarker<MatrixMulForward> benchmarker_cublas(handle);
    size_t RUNS = 1000;
    benchmarker.set_display(false).set_times(RUNS);
    benchmarker_cublas.set_display(false).set_times(RUNS);
    benchmarker_cublas.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("CUBLAS"));
    benchmarker.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
    benchmarker_cublas.set_dtype(0, A_dtype)
            .set_dtype(1, B_dtype)
            .set_dtype(2, C_dtype);
    using Param = MatrixMul::Param;
    for (auto&& arg : args) {
        size_t m = arg.m, n = arg.n, k = arg.k;
        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        param.format = format;
        size_t A0 = m, A1 = k, B0 = k, B1 = n;
        if (param.transposeA) {
            std::swap(A0, A1);
        }
        if (param.transposeB) {
            std::swap(B0, B1);
        }
        benchmarker.set_param(param);
        TensorShape A{A0, A1}, B{B0, B1}, C{m, n};
        float time_in_ms = 0.f;
        if (algo) {
            time_in_ms = algo_benchmark<
                                 MatrixMulForward, OprProxy<MatrixMulForward>, CUTimer>(
                                 benchmarker, {A, B, C}, algo) /
                         RUNS;
        } else {
            time_in_ms = benchmarker.execs({A, B, C}) / RUNS;
        }
        benchmarker_cublas.set_param(param);
        auto time_in_ms_cublas = benchmarker_cublas.execs({A, B, C}) / RUNS;
        float flo = 2.0 * m * n * k / (1e12);
        printf("A=%s, B=%s, C=%s, time(algo=%s)=%.2f %.2fTops, "
               "time(cublas)=%.2f %.2fTops, "
               "perf(algo=%s)/perf(cublas)=%.2f\n",
               A.to_string().c_str(), B.to_string().c_str(), C.to_string().c_str(),
               algo, time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cublas,
               (flo / (time_in_ms_cublas * 1e-3)), algo,
               time_in_ms_cublas / time_in_ms);
    }
}
#endif

} // namespace

TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_128X128X8_32X64X8", args,
            param::MatrixMul::Format::DEFAULT);
}

TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_SPLIT_K_128X128X8_32X64X8", args,
            param::MatrixMul::Format::DEFAULT,
            [](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; });
}

TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", args,
            param::MatrixMul::Format::DEFAULT);
}

TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", args,
            param::MatrixMul::Format::DEFAULT);
}

TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", args,
            param::MatrixMul::Format::DEFAULT);
}
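
// Each entry below is (test index, threadblock tile M/N/K, warp tile M/N/K);
// the numbers are stringified into the name of the CUTLASS algorithm under
// test, e.g. cb(1, 64, 256, 8, 32, 64, 8) selects
// CUTLASS_FLOAT32_SIMT_64X256X8_32X64X8.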
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
    cb(1, 64, 256, 8, 32, 64, 8);         \
    cb(2, 256, 64, 8, 64, 32, 8);         \
    cb(3, 32, 256, 8, 16, 64, 8);         \
    cb(4, 256, 32, 8, 64, 16, 8);         \
    cb(5, 128, 128, 8, 32, 64, 8);        \
    cb(6, 128, 64, 8, 64, 32, 8);         \
    cb(7, 64, 128, 8, 32, 64, 8);         \
    cb(8, 128, 32, 8, 64, 32, 8);         \
    cb(9, 32, 128, 8, 32, 64, 8);         \
    cb(10, 64, 64, 8, 32, 64, 8);         \
    cb(11, 32, 64, 8, 32, 64, 8);         \
    cb(12, 64, 32, 8, 64, 32, 8);         \
    cb(13, 32, 32, 8, 32, 32, 8);         \
    cb(14, 8, 32, 8, 8, 32, 8);           \
    cb(15, 16, 32, 8, 16, 32, 8);         \
    cb(16, 16, 64, 8, 16, 64, 8);         \
    cb(17, 16, 128, 8, 16, 64, 8);

#define cb(name, tbm, tbn, tbk, wm, wn, wk)                                          \
    TEST_F(CUDA, CUTLASS_GEMM_##name) {                                              \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float32(), dtype::Float32(), dtype::Float32(), handle_cuda(), \
                "CUTLASS_FLOAT32_SIMT_" #tbm "X" #tbn "X" #tbk "_" #wm "X" #wn       \
                "X" #wk);                                                            \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#define cb(name, tbm, tbn, tbk, wm, wn, wk)                                          \
    TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_##name) {                                      \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float32(), dtype::Float32(), dtype::Float32(), handle_cuda(), \
                "CUTLASS_FLOAT32_SIMT_SPLIT_K_" #tbm "X" #tbn "X" #tbk "_" #wm "X"   \
                #wn "X" #wk,                                                         \
                param::MatrixMul::Format::DEFAULT, 8, 1e-3,                          \
                matrix_mul::get_matmul_args_split_k());                              \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#undef MEGDNN_FOREACH_CUTLASS_KERNEL

#if CUDA_VERSION >= 10010
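// Float16 tensor-op GEMM: the extra (im, in, ik) parameters are stringified
// into the "h884" part of the algorithm name (presumably the 8x8x4 mma
// instruction shape); these tests require compute capability 7.0.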
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb)     \
    cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
    cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
    cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4);

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik)                              \
    TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) {                                      \
        require_compute_capability(7, 0);                                            \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk,                                             \
                param::MatrixMul::Format::DEFAULT, 8, 1e-2,                          \
                matrix_mul::get_matmul_args());                                      \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik)                              \
    TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) {                              \
        require_compute_capability(7, 0);                                            \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm "X" #tbn  \
                "X" #tbk "_" #wm "X" #wn "X" #wk,                                    \
                param::MatrixMul::Format::DEFAULT, 8, 1e-3,                          \
                matrix_mul::get_matmul_args_split_k(), true,                         \
                param::MatrixMul::ComputeMode::FLOAT32);                             \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#endif

#if CUDA_VERSION >= 10020
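// Same pattern with "h1688" algorithm names (16x8x8 instruction shape),
// guarded by CUDA >= 10.2 and requiring compute capability 7.5.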
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb)      \
    cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
    cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
    cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8);

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik)                              \
    TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) {                                     \
        require_compute_capability(7, 5);                                            \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk,                                             \
                param::MatrixMul::Format::DEFAULT, 8, 1e-2,                          \
                matrix_mul::get_matmul_args(), true,                                 \
                param::MatrixMul::ComputeMode::FLOAT32);                             \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik)                              \
    TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) {                             \
        require_compute_capability(7, 5);                                            \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                              \
                dtype::Float16(), dtype::Float16(), dtype::Float16(), handle_cuda(), \
                "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm "X" #tbn  \
                "X" #tbk "_" #wm "X" #wn "X" #wk,                                    \
                param::MatrixMul::Format::DEFAULT, 8, 1e-3,                          \
                matrix_mul::get_matmul_args_split_k());                              \
    }
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb

#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#endif

#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
    benchmark_matrix_mul(
            handle_cuda(), get_square_matmul_args(), dtype::Float32(), dtype::Float32(),
            dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
}

TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
    benchmark_matrix_mul(
            handle_cuda(), get_feat_model_args(), dtype::Float32(), dtype::Float32(),
            dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
}

#if CUDA_VERSION >= 10010
TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
    benchmark_matrix_mul(
            handle_cuda(), get_f16_feat_model_args(), dtype::Float16(),
            dtype::Float16(), dtype::Float16(), "CUTLASS_FLOAT16_TENSOR_OP");
}
#endif
#endif

} // namespace test
} // namespace megdnn
#endif

// vim: syntax=cpp.doxygen