You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

batched_matrix_mul.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. /**
  2. * \file dnn/test/cuda/batched_matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <vector>
  12. #include "test/cuda/fixture.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/matrix_mul.h"
  15. #include "test/common/rng.h"
  16. #include "test/cuda/benchmark.h"
  17. #include "test/cuda/utils.h"
  18. using namespace megdnn;
  19. using namespace test;
  20. #define F32_TEST_PART(x, algo) \
  21. matrix_mul::check_batched_matrix_mul( \
  22. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), algo, 1e-3, \
  23. matrix_mul::get_batched_matmul_args_mask(x))
  24. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART1) {
  25. F32_TEST_PART(0, "CUBLAS");
  26. }
  27. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART2) {
  28. F32_TEST_PART(1, "CUBLAS");
  29. }
  30. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART3) {
  31. F32_TEST_PART(2, "CUBLAS");
  32. }
  33. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART4) {
  34. F32_TEST_PART(3, "CUBLAS");
  35. }
  36. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART1) {
  37. require_compute_capability(7, 0);
  38. F32_TEST_PART(0, "CUBLAS_LT");
  39. }
  40. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART2) {
  41. require_compute_capability(7, 0);
  42. F32_TEST_PART(1, "CUBLAS_LT");
  43. }
  44. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART3) {
  45. require_compute_capability(7, 0);
  46. F32_TEST_PART(2, "CUBLAS_LT");
  47. }
  48. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) {
  49. require_compute_capability(7, 0);
  50. F32_TEST_PART(3, "CUBLAS_LT");
  51. }
  52. #undef F32_TEST_PART
  53. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART1){
  54. matrix_mul::check_batched_matrix_mul(
  55. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS",
  56. 1e-3, matrix_mul::get_batched_matmul_broadcast_args_mask(0));
  57. }
  58. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART2){
  59. matrix_mul::check_batched_matrix_mul(
  60. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS",
  61. 1e-3, matrix_mul::get_batched_matmul_broadcast_args_mask(1));
  62. }
  63. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART3){
  64. matrix_mul::check_batched_matrix_mul(
  65. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS",
  66. 1e-3, matrix_mul::get_batched_matmul_broadcast_args_mask(2));
  67. }
  68. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BROADCAST_PART4){
  69. matrix_mul::check_batched_matrix_mul(
  70. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), "CUBLAS",
  71. 1e-3, matrix_mul::get_batched_matmul_broadcast_args_mask(3));
  72. }
  73. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) {
  74. matrix_mul::check_batched_matrix_mul(
  75. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  76. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  77. matrix_mul::get_batched_matmul_args_mask(0));
  78. }
  79. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) {
  80. matrix_mul::check_batched_matrix_mul(
  81. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  82. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  83. matrix_mul::get_batched_matmul_args_mask(1));
  84. }
  85. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) {
  86. matrix_mul::check_batched_matrix_mul(
  87. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  88. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  89. matrix_mul::get_batched_matmul_args_mask(2));
  90. }
  91. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) {
  92. matrix_mul::check_batched_matrix_mul(
  93. dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
  94. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
  95. matrix_mul::get_batched_matmul_args_mask(3));
  96. }
  97. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
  98. require_compute_capability(6, 0);
  99. matrix_mul::check_batched_matrix_mul(
  100. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  101. 2e-2, matrix_mul::get_batched_matmul_args_mask(0));
  102. }
  103. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART2) {
  104. require_compute_capability(6, 0);
  105. matrix_mul::check_batched_matrix_mul(
  106. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  107. 2e-2, matrix_mul::get_batched_matmul_args_mask(1));
  108. }
  109. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART3) {
  110. require_compute_capability(6, 0);
  111. matrix_mul::check_batched_matrix_mul(
  112. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  113. 2e-2, matrix_mul::get_batched_matmul_args_mask(2));
  114. }
  115. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART4) {
  116. require_compute_capability(6, 0);
  117. matrix_mul::check_batched_matrix_mul(
  118. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  119. 2e-2, matrix_mul::get_batched_matmul_args_mask(3));
  120. }
  121. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART1) {
  122. require_compute_capability(7, 0);
  123. matrix_mul::check_batched_matrix_mul(
  124. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  125. 2e-2, matrix_mul::get_batched_matmul_args_mask(0));
  126. }
  127. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART2) {
  128. require_compute_capability(7, 0);
  129. matrix_mul::check_batched_matrix_mul(
  130. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  131. 2e-2, matrix_mul::get_batched_matmul_args_mask(1));
  132. }
  133. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART3) {
  134. require_compute_capability(7, 0);
  135. matrix_mul::check_batched_matrix_mul(
  136. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  137. 2e-2, matrix_mul::get_batched_matmul_args_mask(2));
  138. }
  139. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART4) {
  140. require_compute_capability(7, 0);
  141. matrix_mul::check_batched_matrix_mul(
  142. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  143. 2e-2, matrix_mul::get_batched_matmul_args_mask(3));
  144. }
  145. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_INT8) {
  146. require_compute_capability(7, 5);
  147. matrix_mul::check_batched_matrix_mul(
  148. dtype::Int8{}, dtype::Int8{}, {}, handle_cuda(), "CUBLAS_LT", 1e-3,
  149. matrix_mul::get_batched_matmul_args_cublaslt());
  150. }
  151. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_QS8) {
  152. require_compute_capability(7, 5);
  153. matrix_mul::check_batched_matrix_mul(
  154. dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {},
  155. handle_cuda(), "CUBLAS_LT", 1e-3,
  156. matrix_mul::get_batched_matmul_args_cublaslt());
  157. }
  158. TEST_F(CUDA, BATCHED_MATRIX_MUL_QS8) {
  159. matrix_mul::check_batched_matrix_mul(dtype::QuantizedS8(1.2f),
  160. dtype::QuantizedS8(1.3f), {},
  161. handle_cuda());
  162. }
  163. TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) {
  164. require_compute_capability(6, 1);
  165. matrix_mul::check_batched_matrix_mul(
  166. dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle_cuda(),
  167. "INT8x8x32", 1e-2, matrix_mul::get_batched_matmul_args_int8x8x32());
  168. }
  169. #if MEGDNN_WITH_BENCHMARK
  170. TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
  171. require_compute_capability(6, 1);
  172. auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k,
  173. const ExecutionPolicyAlgoName& algo1,
  174. const ExecutionPolicyAlgoName& algo2, size_t b = 128) {
  175. size_t RUNS = 10;
  176. CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda());
  177. bencher1.set_display(false).set_times(RUNS);
  178. bencher1.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo1));
  179. CUBenchmarker<BatchedMatrixMul> bencher2(handle_cuda());
  180. bencher2.set_display(false).set_times(RUNS);
  181. bencher2.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo2));
  182. using Param = MatrixMul::Param;
  183. DType stype = dtype::Int8(), dtype = dtype::Int32();
  184. Param param;
  185. UniformIntRNG rng(-128, 127);
  186. param.transposeA = transA;
  187. param.transposeB = transB;
  188. TensorShape A, B;
  189. if (param.transposeA)
  190. A = TensorShape{b, k, m};
  191. else
  192. A = TensorShape{b, m, k};
  193. if (param.transposeB)
  194. B = TensorShape{b, n, k};
  195. else
  196. B = TensorShape{b, k, n};
  197. auto flo = (double)m * n * k * b * 2;
  198. bencher1.set_param(param)
  199. .set_dtype(0, stype)
  200. .set_dtype(1, stype)
  201. .set_dtype(2, dtype)
  202. .set_rng(0, &rng)
  203. .set_rng(1, &rng);
  204. auto time1 = bencher1.execs({A, B, {}}) / RUNS;
  205. auto flops1 = flo / time1 / 1e6;
  206. bencher2.set_param(param)
  207. .set_dtype(0, stype)
  208. .set_dtype(1, stype)
  209. .set_dtype(2, dtype)
  210. .set_rng(0, &rng)
  211. .set_rng(1, &rng);
  212. auto time2 = bencher2.execs({A, B, {}}) / RUNS;
  213. auto flops2 = flo / time2 / 1e6;
  214. printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s "
  215. "/ "
  216. "%s %.3f\n",
  217. transA, transB, m, n, k, b, algo1.name.c_str(),
  218. algo2.name.c_str(), flops1 / flops2);
  219. };
  220. for (bool transA : {0, 1})
  221. for (bool transB : {0, 1}) {
  222. run(transA, transB, 128, 576, 128, "INT8x8x32",
  223. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  224. run(transA, transB, 256, 144, 256, "INT8x8x32",
  225. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  226. run(transA, transB, 512, 36, 512, "INT8x8x32",
  227. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  228. run(transA, transB, 1024, 8, 1024, "INT8x8x32",
  229. ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
  230. }
  231. }
  232. #endif
  233. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台