
batched_matrix_mul.cpp

#include <vector>

#include "test/cuda/fixture.h"

#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/utils.h"

using namespace megdnn;
using namespace test;
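// Run one quarter of the fp32 batched-matmul cases (selected by mask `x`)
// against the given algorithm, with a 1e-3 maximum-error threshold.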
#define F32_TEST_PART(x, algo)                                                 \
    matrix_mul::check_batched_matrix_mul(                                      \
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), algo, 1e-3, \
            matrix_mul::get_batched_matmul_args_mask(x))
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART1) {
    F32_TEST_PART(0, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART2) {
    F32_TEST_PART(1, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART3) {
    F32_TEST_PART(2, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART4) {
    F32_TEST_PART(3, "CUBLAS");
}
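// The same four shards through cuBLASLt, which is only exercised on devices
// with compute capability >= 7.0.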
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART1) {
    require_compute_capability(7, 0);
    F32_TEST_PART(0, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART2) {
    require_compute_capability(7, 0);
    F32_TEST_PART(1, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART3) {
    require_compute_capability(7, 0);
    F32_TEST_PART(2, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) {
    require_compute_capability(7, 0);
    F32_TEST_PART(3, "CUBLAS_LT");
}
#undef F32_TEST_PART
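// Broadcast variants: cases where one input is broadcast along the batch
// dimension (as selected by get_batched_matmul_broadcast_args_mask), again
// split into four shards.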
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART1) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
            matrix_mul::get_batched_matmul_broadcast_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART2) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
            matrix_mul::get_batched_matmul_broadcast_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART3) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
            matrix_mul::get_batched_matmul_broadcast_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART4) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
            matrix_mul::get_batched_matmul_broadcast_args_mask(3));
}
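// BRUTE_FORCE wraps a plain matrix-mul algorithm (here cuBLAS) and applies it
// to each matrix in the batch individually.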
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
            matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
            matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
            matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
            matrix_mul::get_batched_matmul_args_mask(3));
}
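// NAIVE_BMM is the straightforward reference kernel, so it is held to a
// tighter 1e-5 threshold.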
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART0) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART1) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART2) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_NAIVE_PART3) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART0) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART1) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART2) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_NAIVE_PART3) {
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(),
            ExecutionPolicyAlgoName{"NAIVE_BMM"}, 1e-5,
            matrix_mul::get_batched_matmul_args_mask(3));
}
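// fp16 through cuBLAS needs compute capability >= 6.0; half precision gets a
// looser 2e-2 threshold.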
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
    require_compute_capability(6, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART2) {
    require_compute_capability(6, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART3) {
    require_compute_capability(6, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART4) {
    require_compute_capability(6, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(3));
}
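// The same fp16 cases through cuBLASLt (compute capability >= 7.0).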
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART1) {
    require_compute_capability(7, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART2) {
    require_compute_capability(7, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART3) {
    require_compute_capability(7, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART4) {
    require_compute_capability(7, 0);
    matrix_mul::check_batched_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
            matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_INT8) {
    require_compute_capability(7, 5);
    matrix_mul::check_batched_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, {}, handle_cuda(), "CUBLAS_LT", 1e-3,
            matrix_mul::get_batched_matmul_args_cublaslt());
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_QS8) {
    require_compute_capability(7, 5);
    matrix_mul::check_batched_matrix_mul(
            dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {}, handle_cuda(),
            "CUBLAS_LT", 1e-3, matrix_mul::get_batched_matmul_args_cublaslt());
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_QS8) {
    matrix_mul::check_batched_matrix_mul(
            dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {}, handle_cuda());
}
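// INT8x8x32: int8 inputs accumulated into int32 outputs. Compute capability
// 6.1 is where the dp4a int8 dot-product instruction becomes available.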
TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) {
    require_compute_capability(6, 1);
    matrix_mul::check_batched_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle_cuda(), "INT8x8x32",
            1e-2, matrix_mul::get_batched_matmul_args_int8x8x32());
}
#if MEGDNN_WITH_BENCHMARK
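// Benchmark the INT8x8x32 kernel against BRUTE_FORCE + cuBLAS on several
// shapes and print the resulting FLOPS ratio (speedup).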
TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
    require_compute_capability(6, 1);
    auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k,
                   const ExecutionPolicyAlgoName& algo1,
                   const ExecutionPolicyAlgoName& algo2, size_t b = 128) {
        size_t RUNS = 10;
        CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda());
        bencher1.set_display(false).set_times(RUNS);
        bencher1.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo1));
        CUBenchmarker<BatchedMatrixMul> bencher2(handle_cuda());
        bencher2.set_display(false).set_times(RUNS);
        bencher2.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo2));
        using Param = MatrixMul::Param;
        DType stype = dtype::Int8(), dtype = dtype::Int32();
        Param param;
        UniformIntRNG rng(-128, 127);
        param.transposeA = transA;
        param.transposeB = transB;
        TensorShape A, B;
        if (param.transposeA)
            A = TensorShape{b, k, m};
        else
            A = TensorShape{b, m, k};
        if (param.transposeB)
            B = TensorShape{b, n, k};
        else
            B = TensorShape{b, k, n};
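        // One multiply and one add per MAC: 2 * m * n * k FLOPs per matrix
        // product, times b matrices in the batch.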
        auto flo = (double)m * n * k * b * 2;
        bencher1.set_param(param)
                .set_dtype(0, stype)
                .set_dtype(1, stype)
                .set_dtype(2, dtype)
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time1 = bencher1.execs({A, B, {}}) / RUNS;
        auto flops1 = flo / time1 / 1e6;
        bencher2.set_param(param)
                .set_dtype(0, stype)
                .set_dtype(1, stype)
                .set_dtype(2, dtype)
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time2 = bencher2.execs({A, B, {}}) / RUNS;
        auto flops2 = flo / time2 / 1e6;
        printf("trA: %d, trB: %d, m: %zu, n: %zu, k: %zu, b: %zu, "
               "speedup: %s / %s %.3f\n",
               transA, transB, m, n, k, b, algo1.name.c_str(),
               algo2.name.c_str(), flops1 / flops2);
    };
    for (bool transA : {false, true})
        for (bool transB : {false, true}) {
            run(transA, transB, 128, 576, 128, "INT8x8x32",
                ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
            run(transA, transB, 256, 144, 256, "INT8x8x32",
                ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
            run(transA, transB, 512, 36, 512, "INT8x8x32",
                ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
            run(transA, transB, 1024, 8, 1024, "INT8x8x32",
                ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
        }
}
#endif
// vim: syntax=cpp.doxygen