You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. /**
  2. * \file dnn/test/armv7/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/armv7/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/matrix_mul.h"
  16. #include "test/common/rng.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. TEST_F(ARMV7, MATRIX_MUL) {
  20. matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
  21. dtype::Float32{}, handle(), "ARMV7_F32");
  22. }
  23. TEST_F(ARMV7, MATRIX_MUL_MK4) {
  24. matrix_mul::check_matrix_mul(
  25. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  26. "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
  27. }
  28. TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
  29. matrix_mul::check_matrix_mul(
  30. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  31. "ARMV7_F32_MK4_PACK_4X12", param::MatrixMul::Format::MK4, 1);
  32. }
  33. TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) {
  34. std::vector<matrix_mul::TestArg> args;
  35. for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
  36. for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
  37. for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
  38. args.emplace_back(m, n, k, 0);
  39. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  40. handle(), "ARMV7_INT8X8X32_MK4_4X2X16",
  41. param::MatrixMul::Format::MK4, 1, 1e-3,
  42. std::move(args));
  43. }
  44. TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
  45. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  46. handle(), "ARMV7_INT8X8X16_K4X8X8");
  47. }
  48. TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) {
  49. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  50. handle(), "ARMV7_INT8X8X16_MK4_K8X8X4",
  51. param::MatrixMul::Format::MK4, 1);
  52. }
  53. TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) {
  54. matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  55. handle(), "ARMV7_INT16X16X32_K12X4X1");
  56. }
  57. TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) {
  58. matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  59. handle(), "ARMV7_INT16X16X32_MK8_4X8",
  60. param::MatrixMul::Format::MK8, 1);
  61. }
  62. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  63. TEST_F(ARMV7, MATRIX_MUL_FP16) {
  64. matrix_mul::check_matrix_mul(dtype::Float16{}, dtype::Float16{},
  65. dtype::Float16{}, handle(),
  66. "AARCH32_F16_K4X16X1");
  67. }
  68. TEST_F(ARMV7, MATRIX_MUL_F16_MK8) {
  69. matrix_mul::check_matrix_mul(
  70. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
  71. "AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1);
  72. }
  73. #endif
  74. #if __ARM_FEATURE_DOTPROD
  75. TEST_F(ARMV7, MATRIX_MUL_SDOT) {
  76. matrix_mul::check_matrix_mul(dtype::Int8(), dtype::Int8(), dtype::Int32(),
  77. handle(), "AARCH32_INT8_K6X8X4");
  78. }
  79. TEST_F(ARMV7, MATRIX_MUL_UDOT) {
  80. matrix_mul::check_matrix_mul(
  81. dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)),
  82. dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
  83. dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
  84. }
  85. TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
  86. std::vector<matrix_mul::TestArg> args;
  87. for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
  88. for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
  89. for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
  90. args.emplace_back(m, n, k, 0);
  91. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  92. handle(), "AARCH32_INT8_MK4_8X4X4_DOTPROD",
  93. param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
  94. std::move(args));
  95. }
  96. #endif
  97. #if MEGDNN_WITH_BENCHMARK
  98. namespace {
  99. void run_8x8x16_benchmark(
  100. const char* algo, Handle* handle,
  101. MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) {
  102. constexpr size_t RUNS = 50;
  103. param::MatrixMul param;
  104. Benchmarker<MatrixMul> benchmarker_int(handle);
  105. Benchmarker<MatrixMul> benchmarker_int_kern_4x2x16(handle);
  106. benchmarker_int.set_before_exec_callback(
  107. AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X16"));
  108. benchmarker_int.set_times(RUNS)
  109. .set_dtype(0, dtype::Int8{})
  110. .set_dtype(1, dtype::Int8{})
  111. .set_dtype(2, dtype::Int16{})
  112. .set_param(param)
  113. .set_display(false);
  114. param::MatrixMul target_param;
  115. target_param.format = format;
  116. benchmarker_int_kern_4x2x16.set_before_exec_callback(
  117. AlgoChecker<MatrixMul>(algo));
  118. benchmarker_int_kern_4x2x16.set_times(RUNS)
  119. .set_dtype(0, dtype::Int8{})
  120. .set_dtype(1, dtype::Int8{})
  121. .set_dtype(2, dtype::Int16{})
  122. .set_param(target_param)
  123. .set_display(false);
  124. Benchmarker<MatrixMul> benchmarker_float(handle);
  125. benchmarker_float.set_display(false).set_times(RUNS);
  126. auto run = [&](size_t M, size_t N, size_t K) {
  127. auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
  128. auto int_kern_used = 1e10;
  129. if (format == MatrixMul::Param::Format::MK4) {
  130. int_kern_used = benchmarker_int_kern_4x2x16.exec(
  131. {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
  132. RUNS;
  133. } else {
  134. int_kern_used =
  135. benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) /
  136. RUNS;
  137. }
  138. auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
  139. float computations = 2.f * M * K * N * 1e-6;
  140. printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f "
  141. "ms "
  142. "%f Gflops %s: %f ms %f Gflops "
  143. "speedup(%s/arm_common, %s/float): %f "
  144. "%f\n",
  145. M, K, N, float_used, computations / float_used, int_used,
  146. computations / int_used, algo, int_kern_used,
  147. computations / int_kern_used, algo, algo,
  148. int_used / int_kern_used, float_used / int_kern_used);
  149. };
  150. run(256, 12 * 24, 256);
  151. run(256, 256, 256);
  152. //////////////////////// gemv //////////////////////////
  153. for (size_t M : {8, 64, 112, 256}) {
  154. for (size_t K : {8, 64, 112, 256}) {
  155. run(M, 1, K);
  156. }
  157. }
  158. //////////////////////// gemm //////////////////////////
  159. for (size_t M : {8, 64, 112, 256}) {
  160. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  161. for (size_t N : {8, 64, 112, 256}) {
  162. run(M, N, K);
  163. }
  164. }
  165. }
  166. }
  167. void run_16x16x32_benchmark(const char* algo, Handle* handle) {
  168. constexpr size_t RUNS = 50;
  169. param::MatrixMul param;
  170. Benchmarker<MatrixMul> benchmarker_int(handle);
  171. benchmarker_int.set_before_exec_callback(
  172. AlgoChecker<MatrixMul>("ARMV7_INT16X16X32_K12X4X1"));
  173. benchmarker_int.set_times(RUNS)
  174. .set_dtype(0, dtype::Int16{})
  175. .set_dtype(1, dtype::Int16{})
  176. .set_dtype(2, dtype::Int32{})
  177. .set_param(param)
  178. .set_display(false);
  179. Benchmarker<MatrixMul> benchmarker_float(handle);
  180. benchmarker_float.set_display(false).set_times(RUNS);
  181. auto run = [&](size_t M, size_t N, size_t K) {
  182. auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
  183. auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
  184. float computations = 2.f * M * K * N * 1e-6;
  185. printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n"
  186. "int: %f ms %f Gflops %s: \n"
  187. "speedup(%s/arm_common, %s/float): %f\n",
  188. M, K, N, float_used, computations / float_used, int_used,
  189. computations / int_used, algo, algo, algo,
  190. float_used / int_used);
  191. };
  192. run(256, 12 * 24, 256);
  193. //////////////////////// gemv //////////////////////////
  194. for (size_t M : {8, 64, 112, 256}) {
  195. for (size_t K : {8, 64, 112, 256}) {
  196. run(M, 1, K);
  197. }
  198. }
  199. //////////////////////// gemm //////////////////////////
  200. for (size_t M : {8, 64, 112, 256}) {
  201. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  202. for (size_t N :
  203. {1, 2, 3, 4, 8, 64, 112, 113, 114, 115, 256, 257, 258, 259}) {
  204. run(M, N, K);
  205. }
  206. }
  207. }
  208. }
  209. #if __ARM_FEATURE_DOTPROD
  210. void run_8x8x32_benchmark(const char* algo, Handle* handle) {
  211. constexpr size_t RUNS = 50;
  212. param::MatrixMul param;
  213. Benchmarker<MatrixMul> benchmarker_int8(handle);
  214. benchmarker_int8.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
  215. benchmarker_int8.set_times(RUNS)
  216. .set_dtype(0, dtype::Int8{})
  217. .set_dtype(1, dtype::Int8{})
  218. .set_dtype(2, dtype::Int32{})
  219. .set_param(param)
  220. .set_display(false);
  221. Benchmarker<MatrixMul> benchmarker_float(handle);
  222. benchmarker_float.set_display(false).set_times(RUNS);
  223. auto run = [&](size_t M, size_t N, size_t K) {
  224. auto int_used = benchmarker_int8.exec({{M, K}, {K, N}, {}}) / RUNS;
  225. auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
  226. float computations = 2.f * M * K * N * 1e-6;
  227. printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops \n"
  228. "int: %f ms %f Gflops %s: \n"
  229. "speedup(%s/arm_common, %s/float): %f\n",
  230. M, K, N, float_used, computations / float_used, int_used,
  231. computations / int_used, algo, algo, algo,
  232. float_used / int_used);
  233. };
  234. run(256, 12 * 24, 256);
  235. //////////////////////// gemm //////////////////////////
  236. for (size_t M : {8, 64, 112, 256}) {
  237. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  238. for (size_t N : {113, 114, 115, 256, 1024}) {
  239. run(M, N, K);
  240. }
  241. }
  242. }
  243. }
  244. void run_8x8x32_quint_benchmark(Handle* handle) {
  245. constexpr size_t RUNS = 50;
  246. param::MatrixMul param;
  247. Benchmarker<MatrixMul> benchmarker_quint8_dot(handle);
  248. benchmarker_quint8_dot.set_before_exec_callback(
  249. AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4"));
  250. benchmarker_quint8_dot.set_times(RUNS)
  251. .set_dtype(0,
  252. dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
  253. .set_dtype(1,
  254. dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
  255. .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
  256. .set_param(param)
  257. .set_display(false);
  258. Benchmarker<MatrixMul> benchmarker_quint8(handle);
  259. benchmarker_quint8.set_before_exec_callback(
  260. AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8"));
  261. benchmarker_quint8.set_times(RUNS)
  262. .set_dtype(0,
  263. dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20)))
  264. .set_dtype(1,
  265. dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30)))
  266. .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f))
  267. .set_param(param)
  268. .set_display(false);
  269. auto run = [&](size_t M, size_t N, size_t K) {
  270. auto dot_used =
  271. benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS;
  272. auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS;
  273. float computations = 2.f * M * K * N * 1e-6;
  274. printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n"
  275. "normal: %f ms %f Gflops.speedup: %f\n",
  276. M, K, N, dot_used, computations / dot_used, normal_used,
  277. computations / normal_used, normal_used / dot_used);
  278. };
  279. run(256, 12 * 24, 256);
  280. //////////////////////// gemm //////////////////////////
  281. for (size_t M : {8, 64, 112, 256}) {
  282. for (size_t K : {8, 16, 32, 64, 112, 256}) {
  283. for (size_t N : {113, 114, 115, 256, 1024}) {
  284. run(M, N, K);
  285. }
  286. }
  287. }
  288. }
  289. #endif
  290. } // namespace
  291. #if __ARM_FEATURE_DOTPROD
  292. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
  293. run_8x8x32_benchmark("AARCH32_INT8_K6X8X4", handle());
  294. }
  295. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) {
  296. run_8x8x32_quint_benchmark(handle());
  297. }
  298. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
  299. constexpr size_t RUNS = 50;
  300. param::MatrixMul param;
  301. Benchmarker<MatrixMul> benchmarker_default(handle());
  302. benchmarker_default.set_times(RUNS)
  303. .set_dtype(0, dtype::Int8())
  304. .set_dtype(1, dtype::Int8())
  305. .set_dtype(2, dtype::Int32())
  306. .set_param(param)
  307. .set_display(false);
  308. benchmarker_default.set_before_exec_callback(
  309. AlgoChecker<MatrixMul>("AARCH32_INT8_K6X8X4"));
  310. param.format = MatrixMul::Param::Format::MK4_DOT;
  311. Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
  312. benchmarker_mk4_dot.set_before_exec_callback(
  313. AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X4X4_DOTPROD"));
  314. benchmarker_mk4_dot.set_param(param)
  315. .set_dtype(0, dtype::Int8())
  316. .set_dtype(1, dtype::Int8())
  317. .set_dtype(2, dtype::Int32())
  318. .set_display(false)
  319. .set_times(RUNS);
  320. auto run = [&](size_t M, size_t N, size_t K) {
  321. auto default_used =
  322. benchmarker_default.exec({{M, K}, {K, N}, {}}) / RUNS;
  323. auto mk4_dot_used = benchmarker_mk4_dot.exec(
  324. {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
  325. RUNS;
  326. float computations = 2.f * M * K * N * 1e-6;
  327. printf("run: {%zu{M} %zu{K} %zu{N}} default: %f ms %f Gflops mk4_dot: "
  328. "%f ms "
  329. "%f Gflops speedup: %f\n",
  330. M, K, N, default_used, computations / default_used, mk4_dot_used,
  331. computations / mk4_dot_used, default_used / mk4_dot_used);
  332. };
  333. for (size_t M = 4; M < 512; M *= 2) {
  334. for (size_t K = 4; K < 512; K *= 2) {
  335. for (size_t N : {4, 8, 33, 113, 128}) {
  336. run(M, N, K);
  337. }
  338. }
  339. }
  340. }
  341. #endif
  342. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
  343. run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle());
  344. }
  345. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
  346. run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
  347. }
  348. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) {
  349. run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(),
  350. MatrixMul::Param::Format::MK4);
  351. }
  352. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
  353. run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
  354. }
  355. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  356. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) {
  357. constexpr size_t RUNS = 50;
  358. param::MatrixMul param;
  359. Benchmarker<MatrixMul> benchmarker_fp16(handle());
  360. benchmarker_fp16.set_times(RUNS)
  361. .set_dtype(0, dtype::Float16())
  362. .set_dtype(1, dtype::Float16())
  363. .set_dtype(2, dtype::Float16())
  364. .set_param(param)
  365. .set_display(false);
  366. Benchmarker<MatrixMul> benchmarker_float(handle());
  367. benchmarker_float.set_param(param).set_display(false).set_times(RUNS);
  368. auto run = [&](size_t M, size_t N, size_t K) {
  369. auto fp16_used = benchmarker_fp16.exec({{M, K}, {K, N}, {}}) / RUNS;
  370. auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
  371. float computations = 2.f * M * K * N * 1e-6;
  372. printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops fp16: %f ms "
  373. "%f Gflops speedup: %f\n",
  374. M, K, N, float_used, computations / float_used, fp16_used,
  375. computations / fp16_used, float_used / fp16_used);
  376. };
  377. run(256, 12 * 24, 256);
  378. for (size_t M : {8, 64, 112, 256}) {
  379. for (size_t K : {8, 64, 112, 256}) {
  380. for (size_t N : {8, 64, 112, 256}) {
  381. run(M, N, K);
  382. }
  383. }
  384. }
  385. }
  386. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_F16_MK8) {
  387. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
  388. matrix_mul::benchmark_with_contrast(
  389. handle(), args, dtype::Float16{}, dtype::Float16{},
  390. dtype::Float16{}, "AARCH32_F16_MK8_4X8",
  391. param::MatrixMul::Format::MK8, dtype::Float16{}, dtype::Float16{},
  392. dtype::Float16{}, "AARCH32_F16_K4X16X1");
  393. }
  394. #endif
  395. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
  396. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
  397. matrix_mul::benchmark_with_contrast(
  398. handle(), args, dtype::Float32{}, dtype::Float32{},
  399. dtype::Float32{}, "ARMV7_F32_MK4_4x8",
  400. param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
  401. dtype::Float32{});
  402. }
  403. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
  404. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
  405. matrix_mul::benchmark_with_contrast(
  406. handle(), args, dtype::Float32{}, dtype::Float32{},
  407. dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12",
  408. param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
  409. dtype::Float32{});
  410. }
  411. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
  412. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
  413. matrix_mul::benchmark_with_contrast(
  414. handle(), args, dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  415. "ARMV7_INT16X16X32_MK8_4X8", param::MatrixMul::Format::MK8,
  416. dtype::Int16{}, dtype::Int16{}, dtype::Int32{});
  417. }
  418. TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT32_MK_4X2X16) {
  419. constexpr size_t RUNS = 50;
  420. param::MatrixMul param;
  421. param.transposeA = false;
  422. param.transposeB = false;
  423. Benchmarker<MatrixMul> benchmarker(handle());
  424. Benchmarker<MatrixMul> benchmarker_mk4(handle());
  425. benchmarker.set_times(RUNS)
  426. .set_dtype(0, dtype::Int8{})
  427. .set_dtype(1, dtype::Int8{})
  428. .set_dtype(2, dtype::Int32{})
  429. .set_param(param)
  430. .set_display(false);
  431. benchmarker.set_before_exec_callback(
  432. AlgoChecker<MatrixMul>("ARMV7_INT8X8X32_K4X2X16"));
  433. param.format = MatrixMul::Param::Format::MK4;
  434. benchmarker_mk4.set_before_exec_callback(
  435. AlgoChecker<MatrixMul>("ARMV7_INT8X8X32_MK4_4X2X16"));
  436. benchmarker_mk4.set_times(RUNS)
  437. .set_dtype(0, dtype::Int8{})
  438. .set_dtype(1, dtype::Int8{})
  439. .set_dtype(2, dtype::Int32{})
  440. .set_param(param)
  441. .set_display(false);
  442. auto run = [&](size_t M, size_t N, size_t K) {
  443. auto mk_used = benchmarker_mk4.exec(
  444. {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
  445. RUNS;
  446. auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
  447. float computations = 2.f * M * K * N * 1e-6;
  448. printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
  449. "%f Gflops speedup_vs_normal: %f\n",
  450. M, K, N, default_used, computations / default_used, mk_used,
  451. computations / mk_used, default_used / mk_used);
  452. };
  453. run(256, 256, 128);
  454. for (size_t k = 4; k <= 512; k *= 2) {
  455. for (size_t m = 4; m <= 512; m *= 2) {
  456. for (size_t n = 4; n <= 512; n *= 2) {
  457. run(m, n, k);
  458. }
  459. }
  460. std::cout << std::endl;
  461. }
  462. }
  463. #endif
  464. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。