You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

batched_matrix_mul.cpp 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /**
  2. * \file dnn/test/cuda/batched_matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/cuda/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/matrix_mul.h"
  14. #include "test/common/rng.h"
  15. #include "test/cuda/benchmark.h"
  16. #include "test/cuda/utils.h"
  17. using namespace megdnn;
  18. using namespace test;
// Helper macro: run the batched matmul checker in Float32 on the CUDA handle,
// restricted to partition `x` of the argument set (see
// get_batched_matmul_args_mask), with algorithm `algo` and 1e-3 error bound.
// Splitting the args into partitions keeps each TEST_F's runtime bounded.
#define F32_TEST_PART(x, algo) \
matrix_mul::check_batched_matrix_mul( \
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), algo, 1e-3, \
matrix_mul::get_batched_matmul_args_mask(x))
  23. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART1) {
  24. F32_TEST_PART(0, "CUBLAS");
  25. }
  26. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART2) {
  27. F32_TEST_PART(1, "CUBLAS");
  28. }
  29. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART3) {
  30. F32_TEST_PART(2, "CUBLAS");
  31. }
  32. TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_PART4) {
  33. F32_TEST_PART(3, "CUBLAS");
  34. }
  35. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART1) {
  36. require_compute_capability(7, 0);
  37. F32_TEST_PART(0, "CUBLAS_LT");
  38. }
  39. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART2) {
  40. require_compute_capability(7, 0);
  41. F32_TEST_PART(1, "CUBLAS_LT");
  42. }
  43. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART3) {
  44. require_compute_capability(7, 0);
  45. F32_TEST_PART(2, "CUBLAS_LT");
  46. }
  47. TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) {
  48. require_compute_capability(7, 0);
  49. F32_TEST_PART(3, "CUBLAS_LT");
  50. }
  51. #undef F32_TEST_PART
  52. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
  53. require_compute_capability(6, 0);
  54. matrix_mul::check_batched_matrix_mul(
  55. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  56. 2e-2, matrix_mul::get_batched_matmul_args_mask(0));
  57. }
  58. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART2) {
  59. require_compute_capability(6, 0);
  60. matrix_mul::check_batched_matrix_mul(
  61. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  62. 2e-2, matrix_mul::get_batched_matmul_args_mask(1));
  63. }
  64. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART3) {
  65. require_compute_capability(6, 0);
  66. matrix_mul::check_batched_matrix_mul(
  67. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  68. 2e-2, matrix_mul::get_batched_matmul_args_mask(2));
  69. }
  70. TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART4) {
  71. require_compute_capability(6, 0);
  72. matrix_mul::check_batched_matrix_mul(
  73. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS",
  74. 2e-2, matrix_mul::get_batched_matmul_args_mask(3));
  75. }
  76. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART1) {
  77. require_compute_capability(7, 0);
  78. matrix_mul::check_batched_matrix_mul(
  79. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  80. 2e-2, matrix_mul::get_batched_matmul_args_mask(0));
  81. }
  82. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART2) {
  83. require_compute_capability(7, 0);
  84. matrix_mul::check_batched_matrix_mul(
  85. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  86. 2e-2, matrix_mul::get_batched_matmul_args_mask(1));
  87. }
  88. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART3) {
  89. require_compute_capability(7, 0);
  90. matrix_mul::check_batched_matrix_mul(
  91. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  92. 2e-2, matrix_mul::get_batched_matmul_args_mask(2));
  93. }
  94. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_F16_PART4) {
  95. require_compute_capability(7, 0);
  96. matrix_mul::check_batched_matrix_mul(
  97. dtype::Float16{}, dtype::Float16{}, {}, handle_cuda(), "CUBLAS_LT",
  98. 2e-2, matrix_mul::get_batched_matmul_args_mask(3));
  99. }
  100. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_INT8) {
  101. require_compute_capability(7, 5);
  102. matrix_mul::check_batched_matrix_mul(
  103. dtype::Int8{}, dtype::Int8{}, {}, handle_cuda(), "CUBLAS_LT", 1e-3,
  104. matrix_mul::get_batched_matmul_args_cublaslt());
  105. }
  106. TEST_F(CUDA, BATCHED_MATRIX_MUL_CUBLASLT_QS8) {
  107. require_compute_capability(7, 5);
  108. matrix_mul::check_batched_matrix_mul(
  109. dtype::QuantizedS8(1.2f), dtype::QuantizedS8(1.3f), {},
  110. handle_cuda(), "CUBLAS_LT", 1e-3,
  111. matrix_mul::get_batched_matmul_args_cublaslt());
  112. }
  113. TEST_F(CUDA, BATCHED_MATRIX_MUL_QS8) {
  114. matrix_mul::check_batched_matrix_mul(dtype::QuantizedS8(1.2f),
  115. dtype::QuantizedS8(1.3f), {},
  116. handle_cuda());
  117. }
  118. TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) {
  119. require_compute_capability(6, 1);
  120. matrix_mul::check_batched_matrix_mul(
  121. dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle_cuda(),
  122. "INT8x8x32", 1e-2, matrix_mul::get_batched_matmul_args_int8x8x32());
  123. }
  124. #if MEGDNN_WITH_BENCHMARK
  125. TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
  126. require_compute_capability(6, 1);
  127. auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k,
  128. const char* algo1, const char* algo2, size_t b = 128) {
  129. size_t RUNS = 10;
  130. CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda());
  131. bencher1.set_display(false).set_times(RUNS);
  132. bencher1.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo1));
  133. CUBenchmarker<BatchedMatrixMul> bencher2(handle_cuda());
  134. bencher2.set_display(false).set_times(RUNS);
  135. bencher2.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(algo2));
  136. using Param = MatrixMul::Param;
  137. DType stype = dtype::Int8(), dtype = dtype::Int32();
  138. Param param;
  139. UniformIntRNG rng(-128, 127);
  140. param.transposeA = transA;
  141. param.transposeB = transB;
  142. TensorShape A, B;
  143. if (param.transposeA)
  144. A = TensorShape{b, k, m};
  145. else
  146. A = TensorShape{b, m, k};
  147. if (param.transposeB)
  148. B = TensorShape{b, n, k};
  149. else
  150. B = TensorShape{b, k, n};
  151. auto flo = (double)m * n * k * b * 2;
  152. bencher1.set_param(param)
  153. .set_dtype(0, stype)
  154. .set_dtype(1, stype)
  155. .set_dtype(2, dtype)
  156. .set_rng(0, &rng)
  157. .set_rng(1, &rng);
  158. auto time1 = bencher1.execs({A, B, {}}) / RUNS;
  159. auto flops1 = flo / time1 / 1e6;
  160. bencher2.set_param(param)
  161. .set_dtype(0, stype)
  162. .set_dtype(1, stype)
  163. .set_dtype(2, dtype)
  164. .set_rng(0, &rng)
  165. .set_rng(1, &rng);
  166. auto time2 = bencher2.execs({A, B, {}}) / RUNS;
  167. auto flops2 = flo / time2 / 1e6;
  168. printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s "
  169. "/ "
  170. "%s %.3f\n",
  171. transA, transB, m, n, k, b, algo1, algo2, flops1 / flops2);
  172. };
  173. for (bool transA : {0, 1})
  174. for (bool transB : {0, 1}) {
  175. run(transA, transB, 128, 576, 128, "INT8x8x32",
  176. "BRUTE_FORCE-CUBLAS");
  177. run(transA, transB, 256, 144, 256, "INT8x8x32",
  178. "BRUTE_FORCE-CUBLAS");
  179. run(transA, transB, 512, 36, 512, "INT8x8x32",
  180. "BRUTE_FORCE-CUBLAS");
  181. run(transA, transB, 1024, 8, 1024, "INT8x8x32",
  182. "BRUTE_FORCE-CUBLAS");
  183. }
  184. }
  185. #endif
  186. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台