
cutlass_matmul.cpp
/**
 * \file dnn/test/cuda/cutlass_matmul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include <cuda.h>

#include <cstring>  // std::memcpy
#include <memory>   // std::make_unique
#include <regex>    // std::regex, std::regex_match

#include "megdnn/oprs/linalg.h"
#include "src/common/utils.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
namespace {
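// Checks that \p algo gives consistent results as the m dimension grows: the
// first operand of each case is replicated 8 times along m, and the output of
// the enlarged problem must match the 8-fold replicated output of the
// original problem (bitwise, via the exact Compare below).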
void test_multibatchsize(
        Handle* handle_cuda, DType A_dtype, DType B_dtype, DType C_dtype,
        const char* algo, const std::vector<matrix_mul::TestArg>& args,
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
        const std::function<bool(const matrix_mul::TestArg&)>& filter = {}) {
    Checker<MatrixMulForward> checker(handle_cuda, false);
    if (algo) {
        checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(algo));
    }
    std::unique_ptr<RNG> rng;
    if (A_dtype.enumv() == DTypeEnum::Float32) {
        rng = std::make_unique<UniformFloatRNG>(-1, 1);
        megdnn_assert(B_dtype.enumv() == DTypeEnum::Float32 &&
                      C_dtype.enumv() == DTypeEnum::Float32);
    }
    megdnn_assert(rng != nullptr);
    struct Compare {
        bool is_same(dt_float32 expected, dt_float32 actual) const {
            return expected == actual;
        }
    };
    // copy rhs -> lhs; lhs is 8 times the size of rhs
    auto copy = [](SyncedTensor<dt_float32, Compare>& lhs,
                   SyncedTensor<dt_float32, Compare>& rhs) {
        size_t chunk = rhs.layout().span().dist_byte();
        size_t tot = lhs.layout().span().dist_byte();
        megdnn_assert(tot % chunk == 0);
        char* pointer_lhs = reinterpret_cast<char*>(lhs.ptr_mutable_host());
        const char* pointer_rhs = reinterpret_cast<const char*>(rhs.ptr_host());
        for (size_t i = 0; i < tot; i += chunk) {
            std::memcpy(pointer_lhs + i, pointer_rhs, chunk);
        }
    };
    using Param = param::MatrixMul;
    megdnn_assert(format == Param::Format::DEFAULT);
    for (auto&& arg : args) {
        megdnn_assert(arg.mask == 0x0);
        // make m, n, k big enough
        size_t m = arg.m, n = (arg.n << 3), k = (arg.k << 3);
        size_t m_prime = (m << 3);
        if (filter && filter(arg))
            continue;
        TensorShape A{m, k}, B{k, n}, C{m, n};
        TensorShape A_prime{m_prime, k}, C_prime{m_prime, n};
        SyncedTensor<dt_float32, Compare> A_tensor{handle_cuda, {A, A_dtype}},
                B_tensor{handle_cuda, {B, B_dtype}},
                C_tensor{handle_cuda, {C, C_dtype}},
                A_tensor_prime{handle_cuda, {A_prime, A_dtype}},
                C_tensor_prime{handle_cuda, {C_prime, C_dtype}},
                C_tensor_batch{handle_cuda, {C_prime, C_dtype}};
        rng->gen(A_tensor.tensornd_host());
        rng->gen(B_tensor.tensornd_host());
        copy(A_tensor_prime, A_tensor);
        auto opr_reference = handle_cuda->create_operator<MatrixMulForward>();
        {
            opr_reference->execution_policy().algo.reset();
            for (auto i : opr_reference->get_all_algorithms_info(
                         A_tensor.layout(), B_tensor.layout(),
                         C_tensor.layout())) {
                if (std::regex_match(
                            i.name.c_str(),
                            std::regex("(" + std::string(algo) + ")(.*)"))) {
                    opr_reference->execution_policy().algo = i;
                    break;
                }
            }
            megdnn_assert(opr_reference->execution_policy().algo.valid());
            size_t ws_size = opr_reference->get_workspace_in_bytes(
                    A_tensor.layout(), B_tensor.layout(), C_tensor.layout());
            WorkspaceWrapper ws_reference(handle_cuda, ws_size);
            opr_reference->exec(
                    A_tensor.tensornd_dev(), B_tensor.tensornd_dev(),
                    C_tensor.tensornd_dev(), ws_reference.workspace());
        }
        copy(C_tensor_prime, C_tensor);
        checker.set_dtype(0, A_dtype)
                .set_dtype(1, B_dtype)
                .set_dtype(2, C_dtype)
                .set_epsilon(1e-6)
                .exect({A_tensor_prime.tensornd_host(),
                        B_tensor.tensornd_host(),
                        {}},
                       {{}, {}, C_tensor_prime.tensornd_host()});
        {
            opr_reference->execution_policy().algo.reset();
            for (auto i : opr_reference->get_all_algorithms_info(
                         A_tensor_prime.layout(), B_tensor.layout(),
                         C_tensor_batch.layout())) {
                if (std::regex_match(
                            i.name.c_str(),
                            std::regex("(" + std::string(algo) + ")(.*)"))) {
                    opr_reference->execution_policy().algo = i;
                    break;
                }
            }
            megdnn_assert(opr_reference->execution_policy().algo.valid());
            size_t ws_size = opr_reference->get_workspace_in_bytes(
                    A_tensor_prime.layout(), B_tensor.layout(),
                    C_tensor_batch.layout());
            WorkspaceWrapper ws_reference(handle_cuda, ws_size);
            opr_reference->exec(
                    A_tensor_prime.tensornd_dev(), B_tensor.tensornd_dev(),
                    C_tensor_batch.tensornd_dev(), ws_reference.workspace());
        }
        C_tensor_batch.check_with(C_tensor_prime);
    }
}
#if MEGDNN_WITH_BENCHMARK
struct BenchArgs {
    size_t m, n, k, mask = 0x0;
};
std::vector<BenchArgs> get_square_matmul_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{128, 128, 128});
    args.emplace_back(BenchArgs{256, 256, 256});
    args.emplace_back(BenchArgs{512, 512, 512});
    args.emplace_back(BenchArgs{1024, 1024, 1024});
    args.emplace_back(BenchArgs{2048, 2048, 2048});
    args.emplace_back(BenchArgs{4096, 4096, 4096});
    return args;
}
std::vector<BenchArgs> get_feat_model_args() {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{2, 4096, 4096});
    args.emplace_back(BenchArgs{2, 1024, 6912});
    args.emplace_back(BenchArgs{2, 3456, 3456});
    args.emplace_back(BenchArgs{2, 2304, 2304});
    args.emplace_back(BenchArgs{1, 256, 8192});
    args.emplace_back(BenchArgs{2, 864, 864});
    args.emplace_back(BenchArgs{2, 9, 64});
    args.emplace_back(BenchArgs{4, 4096, 4096});
    args.emplace_back(BenchArgs{4, 1024, 6912});
    args.emplace_back(BenchArgs{4, 3456, 3456});
    args.emplace_back(BenchArgs{4, 2304, 2304});
    args.emplace_back(BenchArgs{2, 256, 8192});
    args.emplace_back(BenchArgs{4, 864, 864});
    args.emplace_back(BenchArgs{4, 9, 64});
    args.emplace_back(BenchArgs{8, 4096, 4096});
    args.emplace_back(BenchArgs{8, 1024, 6912});
    args.emplace_back(BenchArgs{8, 3456, 3456});
    args.emplace_back(BenchArgs{8, 2304, 2304});
    args.emplace_back(BenchArgs{4, 256, 8192});
    args.emplace_back(BenchArgs{8, 864, 864});
    args.emplace_back(BenchArgs{4, 9, 64});
    args.emplace_back(BenchArgs{16, 4096, 4096});
    args.emplace_back(BenchArgs{16, 1024, 6912});
    args.emplace_back(BenchArgs{16, 3456, 3456});
    args.emplace_back(BenchArgs{16, 2304, 2304});
    args.emplace_back(BenchArgs{8, 256, 8192});
    args.emplace_back(BenchArgs{16, 864, 864});
    args.emplace_back(BenchArgs{8, 9, 64});
    args.emplace_back(BenchArgs{32, 4096, 4096});
    args.emplace_back(BenchArgs{32, 1024, 6912});
    args.emplace_back(BenchArgs{32, 3456, 3456});
    args.emplace_back(BenchArgs{32, 2304, 2304});
    args.emplace_back(BenchArgs{16, 256, 8192});
    args.emplace_back(BenchArgs{32, 864, 864});
    args.emplace_back(BenchArgs{32, 9, 64});
    args.emplace_back(BenchArgs{64, 4096, 4096});
    args.emplace_back(BenchArgs{64, 1024, 6912});
    args.emplace_back(BenchArgs{64, 3456, 3456});
    args.emplace_back(BenchArgs{64, 2304, 2304});
    args.emplace_back(BenchArgs{32, 256, 8192});
    args.emplace_back(BenchArgs{64, 864, 864});
    args.emplace_back(BenchArgs{64, 9, 64});
    args.emplace_back(BenchArgs{128, 4096, 4096});
    args.emplace_back(BenchArgs{128, 1024, 6912});
    args.emplace_back(BenchArgs{128, 3456, 3456});
    args.emplace_back(BenchArgs{128, 2304, 2304});
    args.emplace_back(BenchArgs{64, 256, 8192});
    args.emplace_back(BenchArgs{128, 864, 864});
    args.emplace_back(BenchArgs{128, 9, 64});
    return args;
}
void benchmark_matrix_mul(
        Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
        DType B_dtype, DType C_dtype, const char* algo = nullptr,
        param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT) {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
    CUBenchmarker<MatrixMulForward> benchmarker(handle);
    CUBenchmarker<MatrixMulForward> benchmarker_cublas(handle);
    size_t RUNS = 1000;
    benchmarker.set_display(false).set_times(RUNS);
    benchmarker_cublas.set_display(false).set_times(RUNS);
    benchmarker_cublas.set_before_exec_callback(
            AlgoChecker<MatrixMulForward>("CUBLAS"));
    benchmarker.set_dtype(0, A_dtype)
            .set_dtype(1, B_dtype)
            .set_dtype(2, C_dtype);
    benchmarker_cublas.set_dtype(0, A_dtype)
            .set_dtype(1, B_dtype)
            .set_dtype(2, C_dtype);
    using Param = MatrixMul::Param;
    for (auto&& arg : args) {
        size_t m = arg.m, n = arg.n, k = arg.k;
        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        param.format = format;
        size_t A0 = m, A1 = k, B0 = k, B1 = n;
        if (param.transposeA) {
            std::swap(A0, A1);
        }
        if (param.transposeB) {
            std::swap(B0, B1);
        }
        benchmarker.set_param(param);
        TensorShape A{A0, A1}, B{B0, B1}, C{m, n};
        float time_in_ms = 0.f;
        if (algo) {
            time_in_ms =
                    algo_benchmark<MatrixMulForward, OprProxy<MatrixMulForward>,
                                   CUTimer>(benchmarker, {A, B, C}, algo) /
                    RUNS;
        } else {
            time_in_ms = benchmarker.execs({A, B, C}) / RUNS;
        }
        benchmarker_cublas.set_param(param);
        auto time_in_ms_cublas = benchmarker_cublas.execs({A, B, C}) / RUNS;
        float flo = 2.0 * m * n * k / (1e12);
        printf("A=%s, B=%s, C=%s, time(algo=%s)=%.2f %.2fTops, "
               "time(cublas)=%.2f %.2fTops, "
               "perf(algo=%s)/perf(cublas)=%.2f\n",
               A.to_string().c_str(), B.to_string().c_str(),
               C.to_string().c_str(), algo, time_in_ms,
               (flo / (time_in_ms * 1e-3)), time_in_ms_cublas,
               (flo / (time_in_ms_cublas * 1e-3)), algo,
               time_in_ms_cublas / time_in_ms);
    }
}
#endif
}  // namespace

TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
                        dtype::Float32(),
                        "CUTLASS_FLOAT32_SIMT_128X128X8_32X64X8", args,
                        param::MatrixMul::Format::DEFAULT);
}
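// Each row below is (test name, threadblock tile M/N/K, warp tile M/N/K), as
// the parameter names tbm/tbn/tbk and wm/wn/wk of `cb` suggest; the tile
// sizes are stringified into the CUTLASS algorithm name checked by the
// generated test.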
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
    cb(1, 64, 256, 8, 32, 64, 8);         \
    cb(2, 256, 64, 8, 64, 32, 8);         \
    cb(3, 32, 256, 8, 16, 64, 8);         \
    cb(4, 256, 32, 8, 64, 16, 8);         \
    cb(5, 128, 128, 8, 32, 64, 8);        \
    cb(6, 128, 64, 8, 64, 32, 8);         \
    cb(7, 64, 128, 8, 32, 64, 8);         \
    cb(8, 128, 32, 8, 64, 32, 8);         \
    cb(9, 32, 128, 8, 32, 64, 8);         \
    cb(10, 64, 64, 8, 32, 64, 8);         \
    cb(11, 32, 64, 8, 32, 64, 8);         \
    cb(12, 64, 32, 8, 64, 32, 8);         \
    cb(13, 32, 32, 8, 32, 32, 8);         \
    cb(14, 8, 32, 8, 8, 32, 8);           \
    cb(15, 16, 32, 8, 16, 32, 8);         \
    cb(16, 16, 64, 8, 16, 64, 8);         \
    cb(17, 16, 128, 8, 16, 64, 8);

#define cb(name, tbm, tbn, tbk, wm, wn, wk)                                \
    TEST_F(CUDA, CUTLASS_GEMM_##name) {                                    \
        matrix_mul::check_matrix_mul<MatrixMulForward>(                    \
                dtype::Float32(), dtype::Float32(), dtype::Float32(),      \
                handle_cuda(),                                             \
                "CUTLASS_FLOAT32_SIMT_" #tbm "X" #tbn "X" #tbk "_" #wm "X" \
                #wn "X" #wk);                                              \
    }

MEGDNN_FOREACH_CUTLASS_KERNEL(cb)

#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
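// For reference, each `cb(...)` row expands to one test case; row 5, for
// instance, becomes (expansion shown as a comment, not additional code):
//
//   TEST_F(CUDA, CUTLASS_GEMM_5) {
//       matrix_mul::check_matrix_mul<MatrixMulForward>(
//               dtype::Float32(), dtype::Float32(), dtype::Float32(),
//               handle_cuda(), "CUTLASS_FLOAT32_SIMT_128X128X8_32X64X8");
//   }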
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
    benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(),
                         dtype::Float32(), dtype::Float32(), dtype::Float32(),
                         "CUTLASS_FLOAT32_SIMT");
}

TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
    benchmark_matrix_mul(handle_cuda(), get_feat_model_args(), dtype::Float32(),
                         dtype::Float32(), dtype::Float32(),
                         "CUTLASS_FLOAT32_SIMT");
}
#endif

}  // namespace test
}  // namespace megdnn

#endif

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
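For example, a quick way to confirm that the installed package can see a GPU is a minimal Python sketch (assuming the `megengine` pip package is installed):

    import megengine

    # True when a CUDA-capable GPU and a working driver are present.
    print(megengine.is_cuda_available())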