
matrix_mul.cpp

/**
 * \file dnn/test/x86/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/x86/fixture.h"

#include "src/x86/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"

using namespace megdnn;
using namespace test;
using namespace megdnn::x86;
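// Correctness tests: check_matrix_mul runs the named algorithm and compares
// its output against a reference matmul (see test/common/matrix_mul.h).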
#if MEGDNN_X86_WITH_VNNI
TEST_F(X86, MATRIX_MUL_VNNI_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_VNNI");
}
#endif
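// The MKL-DNN int8 kernel is only usable on VNNI-capable CPUs, so the test
// falls back to checking the default algorithm when VNNI is absent.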
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
    if (is_supported(SIMDType::VNNI)) {
        matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{},
                                     dtype::Int32{}, handle(),
                                     "X86_INT8X8X32_MKLDNN");
    } else {
        std::cout << "skip mkldnn matmul check: no VNNI support" << std::endl;
        matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{},
                                     dtype::Int32{}, handle());
    }
}
#endif
//! FIXME: need to add tests of GEMV and QUINT8
TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_AVX2_2X4X16");
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_AVX2_4X16X2");
}
TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
}
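// X86_F32_MKL_PACKA exercises MKL's packed GEMM path, where the A matrix is
// pre-packed once and reused across executions.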
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                 dtype::Float32{}, handle(),
                                 "X86_F32_MKL_PACKA");
}
#endif
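// MK8 is a blocked float layout (8-element packs), so the format is passed
// explicitly instead of the default row-major layout.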
TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                 dtype::Float32{}, handle(), "X86_F32MK8_8X8",
                                 param::MatrixMul::Format::MK8, 1);
}
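// The benchmarks below time each int8 kernel over RUNS iterations and report
// throughput plus speed-up relative to the float X86_F32_BLAS baseline.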
#if MEGDNN_WITH_BENCHMARK
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{}, "X86_F32MK8_8X8", param::MatrixMul::Format::MK8,
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    constexpr size_t RUNS = 50;
    auto rng = std::make_unique<UniformIntRNG>(-127, 127);
#if MEGDNN_X86_WITH_VNNI
    Benchmarker<MatrixMul> benchmarker_vnni(handle());
    benchmarker_vnni.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_vnni.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_VNNI"));
#endif
#if MEGDNN_X86_WITH_MKL_DNN
    Benchmarker<MatrixMul> benchmarker_mkldnn(handle());
    benchmarker_mkldnn.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_display(false)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_mkldnn.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_MKLDNN"));
#endif
    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2(handle());
    benchmarker_avx2_4x16x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_4x16x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_4X16X2"));
    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_avx2_2x4x16.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_2X4X16"));
    Benchmarker<MatrixMul> benchmarker_sse_4x8x2(handle());
    benchmarker_sse_4x8x2.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_SSE_4X8X2"));
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_display(false)
            .set_times(RUNS)
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_float.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_F32_BLAS"));
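    // computations is 2*M*K*N * 1e-6, i.e. MFLOPs; dividing MFLOPs by a time
    // in milliseconds yields GFLOP/s, which is what is printed as "Gflops".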
    auto run = [&](size_t M, size_t N, size_t K) {
        const float computations = 2.f * M * K * N * 1e-6;
        std::cout << "run : {" << M << "," << N << "," << K << "} ";
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        std::cout << "float: " << float_used << " ms, "
                  << computations / float_used << " Gflops, ";
#if MEGDNN_X86_WITH_VNNI
        if (is_supported(SIMDType::VNNI)) {
            auto vnni_used = benchmarker_vnni.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "vnni: " << vnni_used << " ms, "
                      << computations / vnni_used << " Gflops, "
                      << "speed_up " << float_used / vnni_used << ", ";
        }
#endif
#if MEGDNN_X86_WITH_MKL_DNN
        if (is_supported(SIMDType::VNNI)) {
            auto mkldnn_used =
                    benchmarker_mkldnn.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "mkldnn: " << mkldnn_used << " ms, "
                      << computations / mkldnn_used << " Gflops, "
                      << "speed_up " << float_used / mkldnn_used << ", ";
        }
#endif
        if (is_supported(SIMDType::AVX2)) {
            auto avx2_used_4x16x2 =
                    benchmarker_avx2_4x16x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            auto avx2_used_2x4x16 =
                    benchmarker_avx2_2x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "avx2_k2: " << avx2_used_4x16x2
                      << " ms, k2 throughput "
                      << computations / avx2_used_4x16x2 << " Gflops, "
                      << "k2_speed_up " << float_used / avx2_used_4x16x2
                      << ", k16_speed_up " << float_used / avx2_used_2x4x16
                      << ",";
        }
        if (is_supported(SIMDType::SSE4_1)) {
            auto sse_used =
                    benchmarker_sse_4x8x2.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse: " << sse_used << " ms, "
                      << computations / sse_used << " Gflops, "
                      << "speed_up " << float_used / sse_used << ", ";
        }
        std::cout << std::endl;
    };
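    // Sweep a grid of shapes from small (8) to mid-sized (512) per dimension.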
    for (size_t M : {8, 64, 112, 256, 512}) {
        for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {
            for (size_t N : {8, 64, 112, 256, 512}) {
                run(M, N, K);
            }
        }
    }
}
#endif  // MEGDNN_WITH_BENCHMARK
// vim: syntax=cpp.doxygen
