You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970
  1. /**
  2. * \file dnn/test/aarch64/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "test/aarch64/fixture.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/matrix_mul.h"
  16. #include "test/common/rng.h"
  17. #include "test/arm_common/cpuinfo_help.h"
  18. using namespace megdnn;
  19. using namespace test;
  20. TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) {
  21. matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
  22. dtype::Float32{}, handle(),
  23. "AARCH64_F32K8X12X1");
  24. }
  25. #if MGB_ENABLE_CPUINFO
  26. TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A53) {
  27. CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
  28. matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
  29. dtype::Float32{}, handle(),
  30. "AARCH64_F32K8X12X1");
  31. }
  32. TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A55) {
  33. CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
  34. matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
  35. dtype::Float32{}, handle(),
  36. "AARCH64_F32K8X12X1");
  37. }
  38. #endif
  39. TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) {
  40. matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
  41. dtype::Float32{}, handle(),
  42. "AARCH64_F32K4X16X1");
  43. }
  44. TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
  45. matrix_mul::check_matrix_mul(
  46. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  47. "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
  48. }
  49. #if MGB_ENABLE_CPUINFO
  50. TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A53) {
  51. CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
  52. matrix_mul::check_matrix_mul(
  53. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  54. "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
  55. }
  56. TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A55) {
  57. CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
  58. matrix_mul::check_matrix_mul(
  59. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  60. "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
  61. }
  62. #endif
  63. TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
  64. matrix_mul::check_matrix_mul(
  65. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  66. "AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1);
  67. }
  68. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  69. TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) {
  70. matrix_mul::check_matrix_mul(dtype::Float16{}, dtype::Float16{},
  71. dtype::Float16{}, handle(),
  72. "AARCH64_F16_K8X24X1");
  73. }
  74. TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
  75. matrix_mul::check_matrix_mul(
  76. dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
  77. "AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1);
  78. }
  79. #endif
  80. #if __ARM_FEATURE_DOTPROD
  81. TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
  82. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  83. handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
  84. }
  85. TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_MK4_8X12X4_DOTPROD) {
  86. std::vector<matrix_mul::TestArg> args;
  87. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 10, 11})
  88. for (size_t n : {2, 3, 4, 5, 8, 12, 13, 14, 15, 16, 31})
  89. for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
  90. args.emplace_back(m, n, k, 0);
  91. matrix_mul::check_matrix_mul(
  92. dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
  93. "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD",
  94. param::MatrixMul::Format::MK4_DOT, 1, 1e-3, std::move(args));
  95. }
  96. #else
  97. TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K4X4X16) {
  98. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  99. handle(), "AARCH64_INT8X8X32_K4X4X16");
  100. }
  101. TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) {
  102. std::vector<matrix_mul::TestArg> args;
  103. for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
  104. for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
  105. for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
  106. args.emplace_back(m, n, k, 0);
  107. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  108. handle(), "AARCH64_INT8X8X32_MK4_4X4X16",
  109. param::MatrixMul::Format::MK4, 1, 1e-3,
  110. std::move(args));
  111. }
  112. TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) {
  113. std::vector<matrix_mul::TestArg> args;
  114. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17})
  115. for (size_t n :
  116. {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24})
  117. for (size_t k :
  118. {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
  119. 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29})
  120. args.emplace_back(m, n, k, 0);
  121. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  122. handle(), "AARCH64_INT8X8X16_MK4_K8X8X8",
  123. param::MatrixMul::Format::MK4, 1, 1e-3,
  124. std::move(args));
  125. }
  126. TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) {
  127. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  128. handle(), "AARCH64_INT8X8X16_MK4_4X4X8",
  129. param::MatrixMul::Format::MK4, 1);
  130. }
  131. TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16) {
  132. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  133. handle(), "AARCH64_INT8X8X16_MK4_16X12X4",
  134. param::MatrixMul::Format::MK4, 1);
  135. }
  136. TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) {
  137. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
  138. handle(), "AARCH64_INT8X8X32_K8X8X8");
  139. }
  140. #endif
  141. TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K8x8x8) {
  142. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  143. handle(), "AARCH64_INT8X8X16_K8X8X8");
  144. }
  145. TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) {
  146. matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
  147. handle(), "AARCH64_INT8X8X16_K4X4X16");
  148. }
  149. TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) {
  150. param::MatrixMul param;
  151. param.transposeA = false;
  152. param.transposeB = false;
  153. Checker<MatrixMul> checker(handle());
  154. checker.set_dtype(0, dtype::QuantizedS4{0.6})
  155. .set_dtype(1, dtype::QuantizedS4{0.5})
  156. .set_dtype(2, dtype::QuantizedS16{0.6 * 0.5})
  157. .set_param(param);
  158. checker.set_before_exec_callback(
  159. AlgoChecker<MatrixMul>("AARCH64_INT4X4X16_K8X8X8"));
  160. auto run = [&](size_t M, size_t N, size_t K) {
  161. printf("M N K %zu %zu %zu \n", M, N, K);
  162. TensorShape A, B;
  163. if (param.transposeA) {
  164. A = TensorShape{K, M};
  165. } else {
  166. A = TensorShape{M, K};
  167. }
  168. if (param.transposeB) {
  169. B = TensorShape{N, K};
  170. } else {
  171. B = TensorShape{K, N};
  172. }
  173. checker.exec({A, B, {}});
  174. };
  175. for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20})
  176. for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24})
  177. for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32})
  178. run(m, n, k);
  179. for (size_t k = 4; k <= 256; k *= 8) {
  180. for (size_t m = 4; m <= 256; m *= 4) {
  181. for (size_t n = 4; n <= 256; n *= 4) {
  182. run(m, n, k);
  183. }
  184. }
  185. }
  186. param.transposeA = true;
  187. run(8,8,8);
  188. run(16,8,16);
  189. param.transposeB = true;
  190. run(8,8,8);
  191. run(16,16,16);
  192. }
  193. TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) {
  194. matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  195. handle(), "AARCH64_INT16X16X32_K12X8X1");
  196. }
  197. TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) {
  198. matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  199. handle(), "AARCH64_INT16X16X32_MK8_8X8",
  200. param::MatrixMul::Format::MK8, 1);
  201. }
  202. //! FIXME: need to add tests of GEMV and QUINT8
  203. #if MEGDNN_WITH_BENCHMARK
// Benchmark: fp32 K4x16 kernel vs the K8x12 kernel over a range of shapes.
// Prints per-shape timing, Gflops, and their ratio.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K4X16) {
    constexpr size_t RUNS = 50;  // exec() accumulates RUNS runs; divide to get per-run time
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker_K4X16(handle());
    Benchmarker<MatrixMul> benchmarker_K12X8(handle());
    benchmarker_K4X16.set_times(RUNS)
            .set_dtype(0, dtype::Float32{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Float32{})
            .set_param(param)
            .set_display(false);
    // Pin each benchmarker to its specific algorithm.
    benchmarker_K4X16.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_F32K4X16X1"));
    benchmarker_K12X8.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_F32K8X12X1"));
    benchmarker_K12X8.set_times(RUNS)
            .set_dtype(0, dtype::Float32{})
            .set_dtype(1, dtype::Float32{})
            .set_dtype(2, dtype::Float32{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        // Build operand shapes according to the transpose flags.
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{K, M};
        } else {
            A = TensorShape{M, K};
        }
        if (param.transposeB) {
            B = TensorShape{N, K};
        } else {
            B = TensorShape{K, N};
        }
        auto k4x16_used = benchmarker_K4X16.exec({A, B, {}}) / RUNS;
        auto k12x8_used = benchmarker_K12X8.exec({A, B, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} k4x16: %f ms %f Gflops k12x8: %f "
               "ms "
               "%f Gflops k4x16_vs_k12x8: %f\n",
               M, K, N, k4x16_used, computations / k4x16_used, k12x8_used,
               computations / k12x8_used, k12x8_used / k4x16_used);
    };
    run(256, 256, 128);
    run(384, 384, 384);
    // Sweep k geometrically (x8) and m, n geometrically (x4) up to 256.
    for (size_t k = 4; k <= 256; k *= 8) {
        for (size_t m = 4; m <= 256; m *= 4) {
            for (size_t n = 4; n <= 256; n *= 4) {
                run(m, n, k);
            }
            printf("\n");
        }
        printf("\n");
    }
}
// Benchmark: int8x8x16 K8x8x8 kernel vs int8x8x32 K8x8x8 and the default
// fp32 algorithm; prints speedups of int16-output vs both.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker_int(handle());
    Benchmarker<MatrixMul> benchmarker_int32(handle());
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_int.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K8X8X8"));
    benchmarker_int32.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K8X8X8"));
    benchmarker_int32.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    // Float baseline uses whatever algorithm is picked by default.
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_param(param).set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        // Build operand shapes according to the transpose flags.
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{K, M};
        } else {
            A = TensorShape{M, K};
        }
        if (param.transposeB) {
            B = TensorShape{N, K};
        } else {
            B = TensorShape{K, N};
        }
        auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS;
        auto int32_used = benchmarker_int32.exec({A, B, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
               "%f Gflops speedup_vs_fp32: %f, speedup_vs_int32: %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, float_used / int_used,
               int32_used / int_used);
    };
    run(256, 256, 256);
    for (size_t k = 4; k <= 256; k *= 8) {
        for (size_t m = 4; m <= 256; m *= 4) {
            for (size_t n = 4; n <= 256; n *= 4) {
                run(m, n, k);
            }
            std::cout << std::endl;
        }
        std::cout << std::endl;
    }
}
// Benchmark: int8x8x32 default-layout K4x4x16 kernel vs its MK4-format
// counterpart; prints the MK4 speedup per shape.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K4X4X16"));
    // Switch the param to MK4 before configuring the second benchmarker.
    param.format = MatrixMul::Param::Format::MK4;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_MK4_4X4X16"));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        // MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        auto mk_used = benchmarker_mk4.exec(
                               {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                       RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup_vs_normal: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used);
    };
    run(256, 256, 128);
    for (size_t k = 4; k <= 512; k *= 2) {
        for (size_t m = 4; m <= 512; m *= 2) {
            for (size_t n = 4; n <= 512; n *= 2) {
                run(m, n, k);
            }
        }
        std::cout << std::endl;
    }
}
// Benchmark: int8x8x16 default-layout K4x4x16 kernel vs the MK4 4x4x8 and
// MK4 16x12x4 kernels on a single 384^3 shape.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    Benchmarker<MatrixMul> benchmarker_mk4_16x12(handle());
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    // Switch the param to MK4 for the two MK4 benchmarkers below.
    param.format = MatrixMul::Param::Format::MK4;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_mk4_16x12.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_16X12X4"));
    benchmarker_mk4_16x12.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        // MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        auto mk_used = benchmarker_mk4.exec(
                               {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                       RUNS;
        auto mk4_16x12_used =
                benchmarker_mk4_16x12.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup: %f, mk4_16x12 %f Gflops speedup: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used,
               computations / mk4_16x12_used, default_used / mk4_16x12_used);
    };
    run(384, 384, 384);
}
// Benchmark: int8x8x16 K4x4x16 kernel vs the int4 (QuantizedS4 in/S16 out)
// path, over a very large grid of shapes plus hand-picked cases.
TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_int4_4x4x16(handle());
    // int4 path: no algo pinned, output scale = 0.3 * 0.3.
    benchmarker_int4_4x4x16.set_times(RUNS)
            .set_dtype(0, dtype::QuantizedS4{0.3})
            .set_dtype(1, dtype::QuantizedS4{0.3})
            .set_dtype(2, dtype::QuantizedS16{0.09})
            .set_param(param)
            .set_display(false);
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto int4416_used =
                benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal 8x8x16 used: %f ms %f "
               "Gflops int4416 used %f int4416_gflops %f speedup %f\n",
               M, K, N, default_used, computations / default_used, int4416_used,
               computations / int4416_used, default_used / int4416_used);
    };
    // Large dense grid in steps of 32, then assorted edge/representative shapes.
    for (int m = 32; m <= 1024; m += 32)
        for (int n = 32; n <= 1024; n += 32)
            for (int k = 32; k <= 512; k += 32)
                run(m, n, k);
    run(32, 32, 32);
    run(32, 32, 8);
    run(32, 32, 16);
    run(32, 32, 24);
    run(32 * 2, 32 * 2, 32);
    run(32 * 4, 32 * 4, 32);
    run(32 * 6, 32 * 6, 32);
    run(32 * 8, 32 * 8, 32);
    run(32 * 2, 32 * 2, 32 * 2);
    run(32 * 4, 32 * 4, 32 * 3);
    run(32 * 6, 32 * 6, 32 * 4);
    run(32 * 8, 32 * 8, 32 * 5);
    run(32 * 10, 32 * 10, 32 * 10);
    run(384, 384, 384);
    run(256, 256, 384);
    run(512, 512, 384);
    run(1024, 1024, 384);
}
// Benchmark: int8x8x16 default-layout K4x4x16 kernel vs the MK4 K8x8x8 and
// MK4 4x4x8 kernels, over fixed shapes and a power-of-two sweep.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle());
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    // Switch the param to MK4 for the two MK4 benchmarkers below.
    param.format = MatrixMul::Param::Format::MK4;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_K8X8X8"));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_mk4_4x4x8.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
    benchmarker_mk4_4x4x8.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        // MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        auto mk_used = benchmarker_mk4.exec(
                               {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                       RUNS;
        auto mk4_4x4x8_used =
                benchmarker_mk4_4x4x8.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used,
               computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used);
    };
    run(384, 384, 384);
    run(512, 512, 512);
    run(1024, 1024, 384);
    run(256, 256, 384);
    // Note: k sweep is `< 512` (exclusive) unlike the inclusive m/n sweeps.
    for(int m = 32; m <= 512;m*=2)
        for(int n = 32; n <= 512;n*=2)
            for(int k = 32; k < 512;k*=2){
                run(m,n,k);
            }
}
// Benchmark: int8x8x16 K4x4x16 kernel vs int8x8x32 K4x4x16 and the default
// fp32 algorithm; prints speedups of int16-output vs both.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker_int(handle());
    Benchmarker<MatrixMul> benchmarker_int32(handle());
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker_int.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
    benchmarker_int32.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K4X4X16"));
    benchmarker_int32.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    // Float baseline uses whatever algorithm is picked by default.
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_param(param).set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        // Build operand shapes according to the transpose flags.
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{K, M};
        } else {
            A = TensorShape{M, K};
        }
        if (param.transposeB) {
            B = TensorShape{N, K};
        } else {
            B = TensorShape{K, N};
        }
        auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS;
        auto int32_used = benchmarker_int32.exec({A, B, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
               "%f Gflops speedup_vs_fp32: %f, speedup_vs_int32: %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, float_used / int_used,
               int32_used / int_used);
    };
    run(256, 256, 128);
    run(256, 256, 256);
    for (size_t k = 4; k <= 256; k *= 4) {
        for (size_t m = 4; m <= 256; m *= 4) {
            for (size_t n = 4; n <= 256; n *= 4) {
                run(m, n, k);
            }
        }
        std::cout << std::endl;
    }
}
// Benchmark GEMV-shaped matmuls (small M/K, fixed N=112) in fp32 and, when
// available, fp16. Note: this lambda's parameter order is (M, K, N), unlike
// the (M, N, K) order used by the other benchmarks in this file.
TEST_F(AARCH64, BENCHMARK_GEMV) {
    int exec_times = 10;
    Benchmarker<MatrixMul> benchmarker_gemm(handle());
    benchmarker_gemm.set_times(exec_times);
    // Converts (flops / accumulated-time) into Gflops.
    float mod = 1000 * exec_times / 1e9;
    auto run = [&](size_t M, size_t K, size_t N) {
        float time = 1.f, perf = 1.f;
        std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")"
                  << std::endl;
        benchmarker_gemm.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32());
        time = benchmarker_gemm.exec({{M, K}, {K, N}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp32, Performance is " << perf << " Gflops"
                  << std::endl;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        benchmarker_gemm.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16());
        time = benchmarker_gemm.exec({{M, K}, {K, N}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp16, Performance is " << perf << " Gflops"
                  << std::endl;
#endif
    };
    // Warm-up: 50 silent 256x256 matmuls before timing anything.
    std::cout << "warm up:\n";
    for (int i = 0; i < 50; i++) {
        benchmarker_gemm.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_display(false)
                .exec({{256, 256}, {256, 256}, {}});
        benchmarker_gemm.set_display(true);
    }
    // run gemv
    for (size_t M : {1, 2, 3, 4, 5, 6, 7, 8, 64, 256})
        for (size_t K : {1, 2, 3, 4, 5, 6, 7, 8, 64, 256})
            for (size_t N : {112})
                run(M, K, N);
}
  624. #if __ARM_FEATURE_DOTPROD
// Benchmark int8 matmul vs the default float algorithm with BOTH operands
// transposed (A is {K,M}, B is {N,K}).
TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = param.transposeB = true;
    Benchmarker<MatrixMul> benchmarker_int(handle());
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, {})  // output dtype left to be deduced
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_param(param).set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto int_used = benchmarker_int.exec({{K, M}, {N, K}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{K, M}, {N, K}, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
               "%f Gflops speedup: %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, float_used / int_used);
    };
    run(256, 12 * 24, 256);
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 64, 112, 256}) {
            for (size_t N : {8, 64, 112, 256}) {
                run(M, N, K);
            }
        }
    }
}
// Benchmark int8 matmul vs the default float algorithm over a cube of small
// to mid sizes (GEMV-like cases are covered by the M=1/N=1 entries).
TEST_F(AARCH64, BENCHMARK_GEMV_INT_8X8X32) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;  // default param: no transpose
    Benchmarker<MatrixMul> benchmarker_int(handle());
    benchmarker_int.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, {})  // output dtype left to be deduced
            .set_param(param)
            .set_display(false);
    Benchmarker<MatrixMul> benchmarker_float(handle());
    benchmarker_float.set_display(false).set_times(RUNS);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
               "%f Gflops speedup: %f\n",
               M, K, N, float_used, computations / float_used, int_used,
               computations / int_used, float_used / int_used);
    };
    for (size_t M : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256})
        for (size_t N : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256})
            for (size_t K : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 256})
                run(M, N, K);
}
// Benchmark: dotprod int8x8x32 K8x12x4 kernel vs its MK4_DOT-format
// counterpart; prints the MK4 speedup per shape.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K8X12X4"));
    // Switch the param to MK4_DOT before configuring the second benchmarker.
    param.format = MatrixMul::Param::Format::MK4_DOT;
    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"));
    benchmarker_mk4.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int32{})
            .set_param(param)
            .set_display(false);
    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        // MK4 layout: A is (M/4, K/4, 4, 4), B is (K/4, N, 4).
        auto mk_used = benchmarker_mk4.exec(
                               {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                       RUNS;
        // 2*M*K*N flops, scaled so that flops/ms gives Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup_vs_normal: %f\n",
               M, K, N, default_used, computations / default_used, mk_used,
               computations / mk_used, default_used / mk_used);
    };
    run(256, 256, 128);
    for (size_t k = 4; k <= 512; k *= 2) {
        for (size_t m = 4; m <= 512; m *= 2) {
            for (size_t n = 4; n <= 512; n *= 2) {
                run(m, n, k);
            }
        }
        std::cout << std::endl;
    }
}
  727. #endif // __ARM_FEATURE_DOTPROD
  728. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  729. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {
  730. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
  731. matrix_mul::benchmark_with_contrast(
  732. handle(), args, dtype::Float16{}, dtype::Float16{},
  733. dtype::Float16{}, "AARCH64_F16_MK8_8X8",
  734. param::MatrixMul::Format::MK8, dtype::Float16{}, dtype::Float16{},
  735. dtype::Float16{}, "AARCH64_F16_K8X24X1");
  736. }
  737. #endif
  738. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32) {
  739. constexpr size_t RUNS = 50;
  740. Benchmarker<MatrixMul> benchmarker_int(handle());
  741. benchmarker_int.set_times(RUNS)
  742. .set_dtype(0, dtype::Int16{})
  743. .set_dtype(1, dtype::Int16{})
  744. .set_dtype(2, dtype::Int32{})
  745. .set_display(false);
  746. Benchmarker<MatrixMul> benchmarker_float(handle());
  747. benchmarker_float.set_display(false).set_times(RUNS);
  748. auto run = [&](size_t M, size_t N, size_t K, int mask) {
  749. param::MatrixMul param;
  750. param.transposeA = mask & 0x1;
  751. param.transposeB = mask & 0x2;
  752. benchmarker_int.set_param(param);
  753. benchmarker_float.set_param(param);
  754. TensorShape A, B;
  755. if (param.transposeA) {
  756. A = TensorShape{K, M};
  757. } else {
  758. A = TensorShape{M, K};
  759. }
  760. if (param.transposeB) {
  761. B = TensorShape{N, K};
  762. } else {
  763. B = TensorShape{K, N};
  764. }
  765. auto int_used = benchmarker_int.exec({A, B, {}}) / RUNS;
  766. auto float_used = benchmarker_float.exec({A, B, {}}) / RUNS;
  767. float computations = 2.f * M * K * N * 1e-6;
  768. printf("run: {%zu{M} %zu{K} %zu{N} %d{TA} %d{TB}} "
  769. "float: %f ms %f Gflops int: %f ms "
  770. "%f Gflops speedup: %f\n",
  771. M, K, N, param.transposeA, param.transposeB, float_used,
  772. computations / float_used, int_used, computations / int_used,
  773. float_used / int_used);
  774. };
  775. constexpr int mask = 4;
  776. for (auto i = 0; i < mask; i++) {
  777. for (size_t M : {8, 64, 112, 256}) {
  778. for (size_t K : {8, 64, 112, 256}) {
  779. for (size_t N : {8, 64, 112, 256}) {
  780. run(M, N, K, i);
  781. }
  782. }
  783. }
  784. }
  785. }
  786. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) {
  787. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16);
  788. matrix_mul::benchmark_with_contrast(
  789. handle(), args, dtype::Float32{}, dtype::Float32{},
  790. dtype::Float32{}, "AARCH64_F32_MK4_4x16",
  791. param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
  792. dtype::Float32{});
  793. }
  794. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_PACK_MK4) {
  795. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16);
  796. matrix_mul::benchmark_with_contrast(
  797. handle(), args, dtype::Float32{}, dtype::Float32{},
  798. dtype::Float32{}, "AARCH64_F32_MK4_K8X12X1",
  799. param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
  800. dtype::Float32{}, "AARCH64_F32K8X12X1");
  801. }
  802. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
  803. auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
  804. matrix_mul::benchmark_with_contrast(
  805. handle(), args, dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
  806. "AARCH64_INT16X16X32_MK8_8X8", param::MatrixMul::Format::MK8,
  807. dtype::Int16{}, dtype::Int16{}, dtype::Int32{});
  808. }
  809. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K8X12) {
  810. constexpr size_t RUNS = 50;
  811. param::MatrixMul param;
  812. param.transposeA = param.transposeB = true;
  813. Benchmarker<MatrixMul> benchmarker_k12x8(handle());
  814. Benchmarker<MatrixMul> benchmarker_k8x12(handle());
  815. benchmarker_k12x8.set_param(param).set_display(false).set_times(RUNS);
  816. benchmarker_k8x12.set_param(param).set_display(false).set_times(RUNS);
  817. benchmarker_k12x8.set_before_exec_callback(
  818. AlgoChecker<MatrixMul>("AARCH64_F32K4X16X1"));
  819. benchmarker_k8x12.set_before_exec_callback(
  820. AlgoChecker<MatrixMul>("AARCH64_F32K8X12X1"));
  821. auto run = [&](size_t M, size_t N, size_t K) {
  822. auto k12x8_used = benchmarker_k12x8.exec({{K, M}, {N, K}, {}}) / RUNS;
  823. auto k8x12_used = benchmarker_k8x12.exec({{K, M}, {N, K}, {}}) / RUNS;
  824. float computations = 2.f * M * K * N * 1e-6;
  825. printf("run: {%zu{M} %zu{K} %zu{N}} float k12x8: %f ms %f Gflops "
  826. "k8x12: %f ms "
  827. "%f Gflops speedup: %f\n",
  828. M, K, N, k12x8_used, computations / k12x8_used, k8x12_used,
  829. computations / k8x12_used, k12x8_used / k8x12_used);
  830. };
  831. run(256, 12 * 24, 256);
  832. for (size_t M : {8, 64, 112, 256}) {
  833. for (size_t K : {8, 64, 112, 256}) {
  834. for (size_t N : {8, 64, 112, 256}) {
  835. run(M, N, K);
  836. }
  837. }
  838. }
  839. }
  840. TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K8X12_NO_TRANS) {
  841. constexpr size_t RUNS = 50;
  842. param::MatrixMul param;
  843. param.transposeA = param.transposeB = false;
  844. Benchmarker<MatrixMul> benchmarker_k12x8(handle());
  845. Benchmarker<MatrixMul> benchmarker_k8x12(handle());
  846. benchmarker_k12x8.set_param(param).set_display(false).set_times(RUNS);
  847. benchmarker_k8x12.set_param(param).set_display(false).set_times(RUNS);
  848. benchmarker_k12x8.set_before_exec_callback(
  849. AlgoChecker<MatrixMul>("AARCH64_F32K4X16X1"));
  850. benchmarker_k8x12.set_before_exec_callback(
  851. AlgoChecker<MatrixMul>("AARCH64_F32K8X12X1"));
  852. auto run = [&](size_t M, size_t N, size_t K) {
  853. auto k12x8_used = benchmarker_k12x8.exec({{M, K}, {K, N}, {}}) / RUNS;
  854. auto k8x12_used = benchmarker_k8x12.exec({{M, K}, {K, N}, {}}) / RUNS;
  855. float computations = 2.f * M * K * N * 1e-6;
  856. printf("run: {%zu{M} %zu{K} %zu{N}} float k12x8: %f ms %f Gflops "
  857. "k8x12: %f ms "
  858. "%f Gflops speedup: %f\n",
  859. M, K, N, k12x8_used, computations / k12x8_used, k8x12_used,
  860. computations / k8x12_used, k12x8_used / k8x12_used);
  861. };
  862. run(256, 12 * 24, 256);
  863. for (size_t M : {8, 64, 112, 256}) {
  864. for (size_t K : {8, 64, 112, 256}) {
  865. for (size_t N : {8, 64, 112, 256}) {
  866. run(M, N, K);
  867. }
  868. }
  869. }
  870. }
  871. #endif // MEGDNN_WITH_BENCHMARK
  872. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台