
matrix_mul.cpp 19 kB

/**
 * \file dnn/test/armv7/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */
#include "test/armv7/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"

using namespace megdnn;
using namespace test;
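
// Correctness tests. check_matrix_mul() sweeps a set of (M, N, K) shapes,
// runs the kernel selected by the algorithm-name string on the given handle,
// and compares the result against a reference implementation using the given
// dtypes (and, where supplied, the packing format and error tolerance).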
TEST_F(ARMV7, MATRIX_MUL) {
    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                 dtype::Float32{}, handle(), "ARMV7_F32");
}

TEST_F(ARMV7, MATRIX_MUL_MK4) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4);
}

TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "ARMV7_F32_MK4_PACK_4X12", param::MatrixMul::Format::MK4, 1);
}

TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) {
    std::vector<matrix_mul::TestArg> args;
    for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
        for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
            for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
                args.emplace_back(m, n, k, 0);
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "ARMV7_INT8X8X32_MK4_4X2X16",
                                 param::MatrixMul::Format::MK4, 1, 1e-3,
                                 std::move(args));
}

TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "ARMV7_INT8X8X16_K4X8X8");
}

TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) {
    matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
                                 handle(), "ARMV7_INT16X16X32_K12X4X1");
}

TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) {
    matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
                                 handle(), "ARMV7_INT16X16X32_MK8_4X8",
                                 param::MatrixMul::Format::MK8, 4);
}
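
// FP16 kernels need half-precision vector arithmetic; the compiler defines
// __ARM_FEATURE_FP16_VECTOR_ARITHMETIC when the target supports it, so these
// tests compile away on plain ARMv7 builds.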
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARMV7, MATRIX_MUL_FP16) {
    matrix_mul::check_matrix_mul(dtype::Float16{}, dtype::Float16{},
                                 dtype::Float16{}, handle(),
                                 "AARCH32_F16_K4X16X1");
}

TEST_F(ARMV7, MATRIX_MUL_F16_MK8) {
    matrix_mul::check_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
            "AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4);
}
#endif
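
// __ARM_FEATURE_DOTPROD is defined when the Armv8.2 dot-product extension is
// available in AArch32 mode; its SDOT/UDOT instructions accumulate four 8-bit
// products per 32-bit lane in a single instruction, which the kernels below
// exploit.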
#if __ARM_FEATURE_DOTPROD
TEST_F(ARMV7, MATRIX_MUL_SDOT) {
    matrix_mul::check_matrix_mul(dtype::Int8(), dtype::Int8(), dtype::Int32(),
                                 handle(), "AARCH32_INT8_K6X8X4");
}

TEST_F(ARMV7, MATRIX_MUL_UDOT) {
    matrix_mul::check_matrix_mul(
            dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)),
            dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
            dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
}

TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
    std::vector<matrix_mul::TestArg> args;
    for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
        for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
            for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
                args.emplace_back(m, n, k, 0);
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "AARCH32_INT8_MK4_8X4X4_DOTPROD",
                                 param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
                                 std::move(args));
}
#endif
#if MEGDNN_WITH_BENCHMARK
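// Benchmark helpers. All of them report throughput in Gflops: one
// (M, K) x (K, N) matmul performs 2 * M * N * K operations (a multiply and an
// add per term), so with `computations = 2 * M * K * N * 1e-6` and exec()
// times averaged to milliseconds per run, `computations / time_ms` gives
// 1e9 operations per second.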
namespace {
void run_8x8x16_benchmark(const char* algo, Handle* handle) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_int(handle);
    Benchmarker<MatrixMul> benchmarker_int_kern_4x2x16(handle);
    benchmarker_int.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X16"));
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_int_kern_4x2x16.set_before_exec_callback(
            AlgoChecker<MatrixMul>(algo));
    benchmarker_int_kern_4x2x16.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle);
    benchmarker_float.set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto int_kern_used =
                benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops "
               "int: %f ms %f Gflops %s: %f ms %f Gflops "
               "speedup(%s/arm_common, %s/float): %f %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, algo, int_kern_used,
               computations / int_kern_used, algo, algo,
               int_used / int_kern_used, float_used / int_kern_used);
    };
    run(256, 12 * 24, 256);

    //////////////////////// gemv //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 64, 112, 256}) {
            run(M, 1, K);
        }
    }

    //////////////////////// gemm //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 16, 32, 64, 112, 256}) {
            for (size_t N : {8, 64, 112, 256}) {
                run(M, N, K);
            }
        }
    }
}
void run_16x16x32_benchmark(const char* algo, Handle* handle) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_int(handle);
    benchmarker_int.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int16{})
            .set_dtype(1, dtype::Int16{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle);
    benchmarker_float.set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n"
               "int: %f ms %f Gflops %s: \n"
               "speedup(%s/arm_common, %s/float): %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, algo, algo, algo,
               float_used / int_used);
    };
    run(256, 12 * 24, 256);

    //////////////////////// gemv //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 64, 112, 256}) {
            run(M, 1, K);
        }
    }

    //////////////////////// gemm //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 16, 32, 64, 112, 256}) {
            for (size_t N :
                 {1, 2, 3, 4, 8, 64, 112, 113, 114, 115, 256, 257, 258, 259}) {
                run(M, N, K);
            }
        }
    }
}
#if __ARM_FEATURE_DOTPROD
void run_8x8x32_benchmark(const char* algo, Handle* handle) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_int8(handle);
    benchmarker_int8.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
    benchmarker_int8.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle);
    benchmarker_float.set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto int_used = benchmarker_int8.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n"
               "int: %f ms %f Gflops %s: \n"
               "speedup(%s/arm_common, %s/float): %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, algo, algo, algo,
               float_used / int_used);
    };
    run(256, 12 * 24, 256);

    //////////////////////// gemm //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 16, 32, 64, 112, 256}) {
            for (size_t N : {113, 114, 115, 256, 1024}) {
                run(M, N, K);
            }
        }
    }
}
void run_8x8x32_quint_benchmark(Handle* handle) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_quint8_dot(handle);
    benchmarker_quint8_dot.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4"));
    benchmarker_quint8_dot.set_times(RUNS)
            .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
            .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
            .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_quint8(handle);
    benchmarker_quint8.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8"));
    benchmarker_quint8.set_times(RUNS)
            .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
            .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
            .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n"
               "normal: %f ms %f Gflops. speedup: %f\n",
               M, K, N, dot_used, computations / dot_used, normal_used,
               computations / normal_used, normal_used / dot_used);
    };
    run(256, 12 * 24, 256);

    //////////////////////// gemm //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 16, 32, 64, 112, 256}) {
            for (size_t N : {113, 114, 115, 256, 1024}) {
                run(M, N, K);
            }
        }
    }
}
#endif
}  // namespace
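
// The cases below bind the helpers above to concrete kernel names so each
// kernel can be profiled individually from the test runner.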
#if __ARM_FEATURE_DOTPROD
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
    run_8x8x32_benchmark("AARCH32_INT8_K6X8X4", handle());
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) {
    run_8x8x32_quint_benchmark(handle());
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_default(handle());
    benchmarker_default.set_times(RUNS)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int32())
            .set_param(param)
            .set_display(false);
    benchmarker_default.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH32_INT8_K6X8X4"));
    param.format = MatrixMul::Param::Format::MK4_DOT;
    Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
    benchmarker_mk4_dot.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X4X4_DOTPROD"));
    benchmarker_mk4_dot.set_param(param)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int32())
            .set_display(false)
            .set_times(RUNS);
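    // For Format::MK4_DOT the operands are pre-packed into 4x4 blocks: an
    // (M, K) matrix A is passed as {M/4, K/4, 4, 4} and B as {K/4, N, 4},
    // so both benchmarkers solve the same logical (M, K) x (K, N) problem.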
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used =
                benchmarker_default.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto mk4_dot_used =
                benchmarker_mk4_dot.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} default: %f ms %f Gflops "
               "mk4_dot: %f ms %f Gflops speedup: %f\n",
               M, K, N, default_used, computations / default_used,
               mk4_dot_used, computations / mk4_dot_used,
               default_used / mk4_dot_used);
    };
    for (size_t M = 4; M < 512; M *= 2) {
        for (size_t K = 4; K < 512; K *= 2) {
            for (size_t N : {4, 8, 33, 113, 128}) {
                run(M, N, K);
            }
        }
    }
}
#endif
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
    run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle());
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
    run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
    run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    Benchmarker<MatrixMul> benchmarker_fp16(handle());
    benchmarker_fp16.set_times(RUNS)
            .set_dtype(0, dtype::Float16())
            .set_dtype(1, dtype::Float16())
            .set_dtype(2, dtype::Float16())
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_param(param).set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto fp16_used = benchmarker_fp16.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops "
               "fp16: %f ms %f Gflops speedup: %f\n",
               M, K, N, float_used, computations / float_used, fp16_used,
               computations / fp16_used, float_used / fp16_used);
    };
    run(256, 12 * 24, 256);
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 64, 112, 256}) {
            for (size_t N : {8, 64, 112, 256}) {
                run(M, N, K);
            }
        }
    }
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_F16_MK8) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float16{}, dtype::Float16{},
            dtype::Float16{}, "AARCH32_F16_MK8_4X8",
            param::MatrixMul::Format::MK8, dtype::Float16{}, dtype::Float16{},
            dtype::Float16{}, "AARCH32_F16_K4X16X1");
}
#endif
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{}, "ARMV7_F32_MK4_4x8",
            param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{});
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12",
            param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{});
}

TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
            "ARMV7_INT16X16X32_MK8_4X8", param::MatrixMul::Format::MK8,
            dtype::Int16{}, dtype::Int16{}, dtype::Int32{});
}
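
// Compare the plain int8x8x32 kernel with its MK4-packed counterpart over
// power-of-two shapes; the exec() shapes for the MK4 benchmarker use the same
// {M/4, K/4, 4, 4} / {K/4, N, 4} packing as the MK4_DOT benchmark above.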
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT32_MK_4X2X16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARMV7_INT8X8X32_K4X2X16"));
    param.format = MatrixMul::Param::Format::MK4;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARMV7_INT8X8X32_MK4_4X2X16"));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto mk_used =
                benchmarker_mk4.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops "
               "mk4: %f ms %f Gflops speedup_vs_normal: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used);
    };
    run(256, 256, 128);
    for (size_t k = 4; k <= 512; k *= 2) {
        for (size_t m = 4; m <= 512; m *= 2) {
            for (size_t n = 4; n <= 512; n *= 2) {
                run(m, n, k);
            }
        }
        std::cout << std::endl;
    }
}
#endif

// vim: syntax=cpp.doxygen
