You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

batched_matrix_mul.cpp 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. #include <vector>
  2. #include "test/cuda/fixture.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/matrix_mul.h"
  5. #include "test/common/rng.h"
  6. #include "test/cuda/benchmark.h"
  7. #include "test/cuda/utils.h"
  8. using namespace megdnn;
  9. using namespace test;
// Runs the x-th quarter of the f32 batched-matmul case set (selected by
// get_batched_matmul_args_mask) against algorithm |algo| with a 1e-3
// error threshold. Used by the BATCHED_MATRIX_MUL_*_F32_PARTn tests.
#define F32_TEST_PART(x, algo) \
    matrix_mul::check_batched_matrix_mul( \
            dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), algo, 1e-3, \
            matrix_mul::get_batched_matmul_args_mask(x))
// f32 batched matmul through the plain cuBLAS algorithm, one quarter of the
// case set per test.
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART1) {
    F32_TEST_PART(0, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART2) {
    F32_TEST_PART(1, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART3) {
    F32_TEST_PART(2, "CUBLAS");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART4) {
    F32_TEST_PART(3, "CUBLAS");
}
// Same f32 case set through the cuBLASLt algorithm; skipped below compute
// capability 7.0 (the requirement enforced by require_compute_capability).
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART1) {
    require_compute_capability(7, 0);
    F32_TEST_PART(0, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART2) {
    require_compute_capability(7, 0);
    F32_TEST_PART(1, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART3) {
    require_compute_capability(7, 0);
    F32_TEST_PART(2, "CUBLAS_LT");
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) {
    require_compute_capability(7, 0);
    F32_TEST_PART(3, "CUBLAS_LT");
}
  42. #undef F32_TEST_PART
  43. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART1) {
  44. matrix_mul::check_batched_matrix_mul(
  45. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
  46. matrix_mul::get_batched_matmul_broadcast_args_mask(0));
  47. }
  48. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART2) {
  49. matrix_mul::check_batched_matrix_mul(
  50. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
  51. matrix_mul::get_batched_matmul_broadcast_args_mask(1));
  52. }
  53. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART3) {
  54. matrix_mul::check_batched_matrix_mul(
  55. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
  56. matrix_mul::get_batched_matmul_broadcast_args_mask(2));
  57. }
  58. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART4) {
  59. matrix_mul::check_batched_matrix_mul(
  60. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS", 1e-3,
  61. matrix_mul::get_batched_matmul_broadcast_args_mask(3));
  62. }
  63. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) {
  64. matrix_mul::check_batched_matrix_mul(
  65. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  66. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  67. matrix_mul::get_batched_matmul_args_mask(0));
  68. }
  69. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) {
  70. matrix_mul::check_batched_matrix_mul(
  71. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  72. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  73. matrix_mul::get_batched_matmul_args_mask(1));
  74. }
  75. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) {
  76. matrix_mul::check_batched_matrix_mul(
  77. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  78. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  79. matrix_mul::get_batched_matmul_args_mask(2));
  80. }
  81. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) {
  82. matrix_mul::check_batched_matrix_mul(
  83. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  84. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  85. matrix_mul::get_batched_matmul_args_mask(3));
  86. }
  87. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
  88. require_compute_capability(6, 0);
  89. matrix_mul::check_batched_matrix_mul(
  90. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
  91. matrix_mul::get_batched_matmul_args_mask(0));
  92. }
  93. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART2) {
  94. require_compute_capability(6, 0);
  95. matrix_mul::check_batched_matrix_mul(
  96. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
  97. matrix_mul::get_batched_matmul_args_mask(1));
  98. }
  99. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART3) {
  100. require_compute_capability(6, 0);
  101. matrix_mul::check_batched_matrix_mul(
  102. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
  103. matrix_mul::get_batched_matmul_args_mask(2));
  104. }
  105. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART4) {
  106. require_compute_capability(6, 0);
  107. matrix_mul::check_batched_matrix_mul(
  108. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS", 2e-2,
  109. matrix_mul::get_batched_matmul_args_mask(3));
  110. }
  111. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART1) {
  112. require_compute_capability(7, 0);
  113. matrix_mul::check_batched_matrix_mul(
  114. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
  115. matrix_mul::get_batched_matmul_args_mask(0));
  116. }
  117. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART2) {
  118. require_compute_capability(7, 0);
  119. matrix_mul::check_batched_matrix_mul(
  120. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
  121. matrix_mul::get_batched_matmul_args_mask(1));
  122. }
  123. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART3) {
  124. require_compute_capability(7, 0);
  125. matrix_mul::check_batched_matrix_mul(
  126. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
  127. matrix_mul::get_batched_matmul_args_mask(2));
  128. }
  129. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART4) {
  130. require_compute_capability(7, 0);
  131. matrix_mul::check_batched_matrix_mul(
  132. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT", 2e-2,
  133. matrix_mul::get_batched_matmul_args_mask(3));
  134. }
// Plain int8 batched matmul through cuBLASLt; skipped below compute
// capability 7.5.
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_INT8) {
    require_compute_capability(7, 5);
    matrix_mul::check_batched_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, {}, handle_cuda(), "CUBLAS_LT", 1e-3,
            matrix_mul::get_batched_matmul_args_cublaslt());
}
// As above, but with quantized-int8 inputs carrying two distinct scales.
TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_QS8) {
    require_compute_capability(7, 5);
    matrix_mul::check_batched_matrix_mul(
            dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {}, handle_cuda(),
            "CUBLAS_LT", 1e-3, matrix_mul::get_batched_matmul_args_cublaslt());
}
// Quantized-int8 batched matmul with no algorithm, threshold or case mask
// pinned — exercises check_batched_matrix_mul's default arguments.
TEST_F(CUDA, BATCHED_MATRIX_MUL_QS8) {
    matrix_mul::check_batched_matrix_mul(
            dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {}, handle_cuda());
}
// int8 inputs accumulated into int32 via the INT8x8x32 algorithm.
// NOTE(review): the 6.1 requirement presumably reflects a dp4a-based
// implementation — confirm against the algorithm's source.
TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) {
    require_compute_capability(6, 1);
    matrix_mul::check_batched_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle_cuda(), "INT8x8x32",
            1e-2, matrix_mul::get_batched_matmul_args_int8x8x32());
}
  157. #if MEGDNN_WITH_BENCHMARK
  158. TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
  159. require_compute_capability(6, 1);
  160. auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k,
  161. const ExecutionPolicyAlgoName& algo1,
  162. const ExecutionPolicyAlgoName& algo2, size_t b = 128) {
  163. size_t RUNS = 10;
  164. CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda());
  165. bencher1.set_display(false).set_times(RUNS);
  166. bencher1.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo1));
  167. CUBenchmarker<BatchedMatrixMul> bencher2(handle_cuda());
  168. bencher2.set_display(false).set_times(RUNS);
  169. bencher2.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo2));
  170. using Param = MatrixMul::Param;
  171. DType stype = dtype::Int8(), dtype = dtype::Int32();
  172. Param param;
  173. UniformIntRNG rng(-128, 127);
  174. param.transposeA = transA;
  175. param.transposeB = transB;
  176. TensorShape A, B;
  177. if (param.transposeA)
  178. A = TensorShape{b, k, m};
  179. else
  180. A = TensorShape{b, m, k};
  181. if (param.transposeB)
  182. B = TensorShape{b, n, k};
  183. else
  184. B = TensorShape{b, k, n};
  185. auto flo = (double)m * n * k * b * 2;
  186. bencher1.set_param(param)
  187. .set_dtype(0, stype)
  188. .set_dtype(1, stype)
  189. .set_dtype(2, dtype)
  190. .set_rng(0, &rng)
  191. .set_rng(1, &rng);
  192. auto time1 = bencher1.execs({A, B, {}}) / RUNS;
  193. auto flops1 = flo / time1 / 1e6;
  194. bencher2.set_param(param)
  195. .set_dtype(0, stype)
  196. .set_dtype(1, stype)
  197. .set_dtype(2, dtype)
  198. .set_rng(0, &rng)
  199. .set_rng(1, &rng);
  200. auto time2 = bencher2.execs({A, B, {}}) / RUNS;
  201. auto flops2 = flo / time2 / 1e6;
  202. printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s "
  203. "/ "
  204. "%s %.3f\n",
  205. transA, transB, m, n, k, b, algo1.name.c_str(), algo2.name.c_str(),
  206. flops1 / flops2);
  207. };
  208. for (bool transA : {0, 1})
  209. for (bool transB : {0, 1}) {
  210. run(transA, transB, 128, 576, 128, "INT8x8x32",
  211. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  212. run(transA, transB, 256, 144, 256, "INT8x8x32",
  213. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  214. run(transA, transB, 512, 36, 512, "INT8x8x32",
  215. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  216. run(transA, transB, 1024, 8, 1024, "INT8x8x32",
  217. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  218. }
  219. }
  220. #endif
  221. // vim: syntax=cpp.doxygen