
matrix_mul.cpp 11 kB

/**
 * \file dnn/test/x86/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/x86/fixture.h"

#include "src/x86/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"

using namespace megdnn;
using namespace test;
using namespace megdnn::x86;
#if MEGDNN_X86_WITH_VNNI
TEST_F(X86, MATRIX_MUL_VNNI_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_VNNI");
}
#endif
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
    if (is_supported(SIMDType::VNNI)) {
        matrix_mul::check_matrix_mul(
                dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
                "X86_INT8X8X32_MKLDNN");
    } else {
        std::cout << "skip mkldnn matmul check: no VNNI support" << std::endl;
        matrix_mul::check_matrix_mul(
                dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle());
    }
}
#endif
//! FIXME: need to add tests of GEMV and QUINT8
TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_AVX2_2X4X16", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_AVX2_4X16X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
}

TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(),
            "X86_INT8X8X16_AVX2", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "X86_INT8X8X16_SSE",
            param::MatrixMul::Format::DEFAULT, 8, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_SSE_4X8X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
}

#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32_MKL_PACKA");
}
#endif

TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, 1, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_AVX2_6x16) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, 1, 1e-3, false);
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, dtype::Float32{},
            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
}

TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{},
            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    constexpr size_t RUNS = 50;
    auto rng = std::make_unique<UniformIntRNG>(-127, 127);
#if MEGDNN_X86_WITH_VNNI
    Benchmarker<MatrixMul> benchmarker_vnni(handle());
    benchmarker_vnni.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_vnni.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_VNNI"));
#endif
#if MEGDNN_X86_WITH_MKL_DNN
    Benchmarker<MatrixMul> benchmarker_mkldnn(handle());
    benchmarker_mkldnn.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_mkldnn.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_MKLDNN"));
#endif
    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2(handle());
    benchmarker_avx2_4x16x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_4x16x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_4X16X2"));
    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2_8816(handle());
    benchmarker_avx2_4x16x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));
    Benchmarker<MatrixMul> benchmarker_sse_4x8x2_8816(handle());
    benchmarker_sse_4x8x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_SSE"));
    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_2x4x16.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_2X4X16"));
    Benchmarker<MatrixMul> benchmarker_sse_4x8x2(handle());
    benchmarker_sse_4x8x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_SSE_4X8X2"));
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_display(false)
            .set_times(RUNS)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_float.set_before_exec_callback(AlgoChecker<MatrixMul>("X86_F32_BLAS"));
    auto run = [&](size_t M, size_t N, size_t K) {
        // an (M,K) x (K,N) GEMM performs 2*M*K*N flops; the 1e-6 scaling makes
        // computations / time_in_ms come out directly in Gflops
        const float computations = 2.f * M * K * N * 1e-6;
        std::cout << "run : {" << M << "," << N << "," << K << "} ";
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        std::cout << "float: " << float_used << " ms, " << computations / float_used
                  << " Gflops, ";
#if MEGDNN_X86_WITH_VNNI
        if (is_supported(SIMDType::VNNI)) {
            auto vnni_used = benchmarker_vnni.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "vnni: " << vnni_used << " ms, " << computations / vnni_used
                      << " Gflops, "
                      << "speed_up " << float_used / vnni_used << ", ";
        }
#endif
#if MEGDNN_X86_WITH_MKL_DNN
        if (is_supported(SIMDType::VNNI)) {
            auto mkldnn_used = benchmarker_mkldnn.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "mkldnn: " << mkldnn_used << " ms, "
                      << computations / mkldnn_used << " Gflops, "
                      << "speed_up " << float_used / mkldnn_used << ", ";
        }
#endif
        if (is_supported(SIMDType::AVX2)) {
            auto avx2_used_4x16x2 =
                    benchmarker_avx2_4x16x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            auto avx2_used_2x4x16 =
                    benchmarker_avx2_2x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "avx2_k2: " << avx2_used_4x16x2 << " ms, k2 throughput "
                      << computations / avx2_used_4x16x2 << " Gflops, "
                      << "k2_speed_up " << float_used / avx2_used_4x16x2
                      << ", k16_speed_up " << float_used / avx2_used_2x4x16 << ",";
            auto avx2_used_4x16x2_8816 =
                    benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "avx2_8816: " << avx2_used_4x16x2_8816
                      << " ms, 8816 throughput " << computations / avx2_used_4x16x2_8816
                      << " Gflops,";
        }
        if (is_supported(SIMDType::SSE4_1)) {
            auto sse_used = benchmarker_sse_4x8x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse: " << sse_used << " ms, " << computations / sse_used
                      << " Gflops, "
                      << "speed_up " << float_used / sse_used << ", ";
            auto sse_used_8816 =
                    benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse_8816: " << sse_used_8816 << " ms, "
                      << computations / sse_used_8816 << " Gflops, ";
        }
        std::cout << std::endl;
    };
    run(256, 256, 256);
    for (size_t M : {8, 64, 112, 256, 512}) {
        for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {
            for (size_t N : {8, 64, 112, 256, 512}) {
                run(M, N, K);
            }
        }
    }
}
#endif  // MEGDNN_WITH_BENCHMARK

// vim: syntax=cpp.doxygen
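
The throughput figures printed by the benchmark above follow the standard GEMM flop count: an (M,K) x (K,N) multiply performs 2*M*K*N floating-point operations (one multiply and one add per accumulation step). A quick sanity check of the run lambda's arithmetic, sketched in Python with a hypothetical timing value:

# Sanity check of the benchmark's Gflops arithmetic (timing is hypothetical).
M = N = K = 256
time_ms = 0.5  # hypothetical measured time per run, in milliseconds

computations = 2.0 * M * K * N * 1e-6  # same scaling as the C++ test
gflops = computations / time_ms        # flops * 1e-6 / ms == Gflops
print(f"{gflops:.1f} Gflops")          # ~67.1 for these numbers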

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. If you want to run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU platform, visit MegStudio.
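
As a minimal sketch, you can check whether the installed package actually sees a usable GPU before running GPU code; this assumes the top-level megengine.is_cuda_available() helper exposed by current MegEngine releases:

import megengine as mge

# Check whether this MegEngine build can see a usable CUDA device.
# Assumes megengine.is_cuda_available(), exposed by current releases.
if mge.is_cuda_available():
    print("CUDA device detected: GPU execution is available")
else:
    print("no CUDA device detected: running on CPU")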