/**
 * \file dnn/test/x86/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/x86/fixture.h"

#include "src/x86/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"

using namespace megdnn;
using namespace test;
using namespace megdnn::x86;
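
// Builds the A/B/C layouts for the first matmul test case, honoring the
// transposeA/transposeB bits of the arg mask, then checks that a recorded
// task replays correctly.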
TEST_F(X86, MATRIX_MUL_RECORD) {
    TaskRecordChecker<MatrixMul> checker(0);
    using Param = MatrixMul::Param;
    auto args = matrix_mul::get_matmul_args();
    auto arg = args[0];
    auto m = arg.m, n = arg.n, k = arg.k;
    auto mask = arg.mask;
    Param param;
    param.transposeA = mask & 1;
    param.transposeB = mask & 2;
    TensorShape AS, BS, CS;
    if (param.transposeA)
        AS = TensorShape{k, m};
    else
        AS = TensorShape{m, k};
    if (param.transposeB)
        BS = TensorShape{n, k};
    else
        BS = TensorShape{k, n};
    CS = TensorShape{m, n};
    TensorLayout AL, BL, CL;
    AL = TensorLayout(AS, dtype::Float32());
    BL = TensorLayout(BS, dtype::Float32());
    CL = TensorLayout(CS, dtype::Float32());
    checker.set_param(param);
    checker.execl({AL, BL, CL});
}
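
// int8 x int8 -> int32 matmul through the VNNI kernel; only compiled in
// when MEGDNN_X86_WITH_VNNI is set at build time.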
#if MEGDNN_X86_WITH_VNNI
TEST_F(X86, MATRIX_MUL_VNNI_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_VNNI");
}
#endif
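
// The MKL-DNN int8 kernel needs VNNI support at runtime; without it, the
// test falls back to checking whichever algorithm the heuristic picks.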
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
    if (is_supported(SIMDType::VNNI)) {
        matrix_mul::check_matrix_mul(
                dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
                "X86_INT8X8X32_MKLDNN");
    } else {
        std::cout << "cannot run mkldnn matmul check: no VNNI support" << std::endl;
        matrix_mul::check_matrix_mul(
                dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle());
    }
}
#endif
//! FIXME: need to add tests of GEMV and QUINT8
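// The numeric suffixes in algorithm names (e.g. 2X4X16, 4X16X2, 4X8X2)
// appear to encode each kernel's register-blocking scheme.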
TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_AVX2_2X4X16", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_AVX2_4X16X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
}

TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(),
            "X86_INT8X8X16_AVX2", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "X86_INT8X8X16_SSE",
            param::MatrixMul::Format::DEFAULT, 8, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(
            dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
            "X86_INT8X8X32_SSE_4X8X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            false);
}
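
// Float32 matmul via MKL's packed GEMM ("PACKA" suggests the A matrix is
// pre-packed); requires MKL built with the packed-GEMM API.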
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32_MKL_PACKA");
}
#endif
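
// Format::MK8 runs the matmul on 8-element blocked tensors instead of the
// plain row-major DEFAULT layout, letting the kernel load full vectors.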
TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, 1, 1e-3, false);
}

TEST_F(X86, MATRIX_MUL_AVX2_6x16) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, 1, 1e-3, false);
}

#if MEGDNN_WITH_BENCHMARK
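// The benchmarks below contrast each x86 kernel against the X86_F32_BLAS
// float baseline on identical shapes.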
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, dtype::Float32{},
            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
}

TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{},
            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
}

TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    constexpr size_t RUNS = 50;
    auto rng = std::make_unique<UniformIntRNG>(-127, 127);
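    // One Benchmarker per kernel; the AlgoChecker before-exec callback pins
    // the algorithm so the dispatch heuristic cannot pick a different one.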
#if MEGDNN_X86_WITH_VNNI
    Benchmarker<MatrixMul> benchmarker_vnni(handle());
    benchmarker_vnni.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_vnni.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_VNNI"));
#endif

#if MEGDNN_X86_WITH_MKL_DNN
    Benchmarker<MatrixMul> benchmarker_mkldnn(handle());
    benchmarker_mkldnn.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_mkldnn.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_MKLDNN"));
#endif

    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2(handle());
    benchmarker_avx2_4x16x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_4x16x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_4X16X2"));

    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2_8816(handle());
    benchmarker_avx2_4x16x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));

    Benchmarker<MatrixMul> benchmarker_sse_4x8x2_8816(handle());
    benchmarker_sse_4x8x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_SSE"));

    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_2x4x16.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_2X4X16"));

    Benchmarker<MatrixMul> benchmarker_sse_4x8x2(handle());
    benchmarker_sse_4x8x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_SSE_4X8X2"));

    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_display(false)
            .set_times(RUNS)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_float.set_before_exec_callback(AlgoChecker<MatrixMul>("X86_F32_BLAS"));
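
    // 2*M*K*N multiply-adds scaled by 1e-6 gives Mflop; dividing by the
    // per-run time in ms therefore yields Gflop/s.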
    auto run = [&](size_t M, size_t N, size_t K) {
        const float computations = 2.f * M * K * N * 1e-6;
        std::cout << "run : {" << M << "," << N << "," << K << "} ";
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        std::cout << "float: " << float_used << " ms, " << computations / float_used
                  << " Gflops, ";
#if MEGDNN_X86_WITH_VNNI
        if (is_supported(SIMDType::VNNI)) {
            auto vnni_used = benchmarker_vnni.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "vnni: " << vnni_used << " ms, " << computations / vnni_used
                      << " Gflops, "
                      << "speed_up " << float_used / vnni_used << ", ";
        }
#endif
#if MEGDNN_X86_WITH_MKL_DNN
        if (is_supported(SIMDType::VNNI)) {
            auto mkldnn_used = benchmarker_mkldnn.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "mkldnn: " << mkldnn_used << " ms, "
                      << computations / mkldnn_used << " Gflops, "
                      << "speed_up " << float_used / mkldnn_used << ", ";
        }
#endif
        if (is_supported(SIMDType::AVX2)) {
            auto avx2_used_4x16x2 =
                    benchmarker_avx2_4x16x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            auto avx2_used_2x4x16 =
                    benchmarker_avx2_2x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "avx2_k2: " << avx2_used_4x16x2 << " ms, k2 throughput "
                      << computations / avx2_used_4x16x2 << " Gflops, "
                      << "k2_speed_up " << float_used / avx2_used_4x16x2
                      << ", k16_speed_up " << float_used / avx2_used_2x4x16 << ",";
            auto avx2_used_4x16x2_8816 =
                    benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "avx2_8816: " << avx2_used_4x16x2_8816
                      << " ms, 8816 throughput " << computations / avx2_used_4x16x2_8816
                      << " Gflops,";
        }
        if (is_supported(SIMDType::SSE4_1)) {
            auto sse_used = benchmarker_sse_4x8x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse: " << sse_used << " ms, " << computations / sse_used
                      << " Gflops, "
                      << "speed_up " << float_used / sse_used << ", ";
            auto sse_used_8816 =
                    benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse_8816: " << sse_used_8816 << " ms, "
                      << computations / sse_used_8816 << " Gflops, ";
        }
        std::cout << std::endl;
    };
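    // One large square case first, then a sweep over small-to-medium M/K/N.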
    run(256, 256, 256);
    for (size_t M : {8, 64, 112, 256, 512}) {
        for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {
            for (size_t N : {8, 64, 112, 256, 512}) {
                run(M, N, K);
            }
        }
    }
}
#endif  // MEGDNN_WITH_BENCHMARK

// vim: syntax=cpp.doxygen