You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batched_matrix_mul.cpp 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
/**
 * \file dnn/test/aarch64/batched_matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/aarch64/fixture.h"

#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"
  17. namespace megdnn {
  18. namespace test {
  19. TEST_F(AARCH64, BATCHED_MATRIX_MUL) {
  20. Checker<BatchedMatrixMul> checker(handle());
  21. checker.set_epsilon(1e-2);
  22. using Param = MatrixMul::Param;
  23. // auto args = get_batch_matmul_args();
  24. auto args = matrix_mul::get_batched_matmul_args();
  25. for (DType dtype : std::vector<DType>{dtype::Float32()}) {
  26. for (unsigned mask = 0; mask < 4; ++mask) {
  27. for (auto& arg : args) {
  28. size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  29. //! if test all batch sizes, the test case will time out.
  30. if (b != 2) {
  31. continue;
  32. }
  33. Param param;
  34. param.transposeA = mask & 1;
  35. param.transposeB = mask & 2;
  36. TensorShape A, B;
  37. if (param.transposeA)
  38. A = TensorShape{b, k, m};
  39. else
  40. A = TensorShape{b, m, k};
  41. if (param.transposeB)
  42. B = TensorShape{b, n, k};
  43. else
  44. B = TensorShape{b, k, n};
  45. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).execs(
  46. {A, B, {}});
  47. }
  48. }
  49. }
  50. }
  51. TEST_F(AARCH64, BATCHED_MATRIX_MUL_RECORD) {
  52. TaskRecordChecker<BatchedMatrixMul> checker(0);
  53. checker.set_epsilon(1e-2);
  54. using Param = MatrixMul::Param;
  55. // auto args = get_batch_matmul_args();
  56. auto args = matrix_mul::get_batched_matmul_args();
  57. for (DType dtype : std::vector<DType>{dtype::Float32()}) {
  58. for (unsigned mask = 0; mask < 4; ++mask) {
  59. for (auto& arg : args) {
  60. size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  61. //! if test all batch sizes, the test case will time out.
  62. if (b != 2) {
  63. continue;
  64. }
  65. Param param;
  66. param.transposeA = mask & 1;
  67. param.transposeB = mask & 2;
  68. TensorShape A, B;
  69. if (param.transposeA)
  70. A = TensorShape{b, k, m};
  71. else
  72. A = TensorShape{b, m, k};
  73. if (param.transposeB)
  74. B = TensorShape{b, n, k};
  75. else
  76. B = TensorShape{b, k, n};
  77. checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).execs(
  78. {A, B, {}});
  79. }
  80. }
  81. }
  82. }
  83. #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  84. TEST_F(AARCH64, BATCHED_MATRIX_MUL_FP16) {
  85. Checker<BatchedMatrixMul> checker(handle());
  86. using Param = MatrixMul::Param;
  87. auto args = matrix_mul::get_batched_matmul_args();
  88. NormalRNG rng(1.f);
  89. checker.set_rng(0, &rng).set_rng(1, &rng).set_epsilon(1e-2);
  90. for (DType dtype : std::vector<DType>{dtype::Float16()}) {
  91. for (unsigned mask = 0; mask < 4; ++mask) {
  92. for (auto& arg : args) {
  93. size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  94. //! if test all batch sizes, the test case will time out on
  95. //! sdm855
  96. if (b != 1) {
  97. continue;
  98. }
  99. Param param;
  100. param.transposeA = mask & 1;
  101. param.transposeB = mask & 2;
  102. TensorShape A, B;
  103. if (param.transposeA)
  104. A = TensorShape{b, k, m};
  105. else
  106. A = TensorShape{b, m, k};
  107. if (param.transposeB)
  108. B = TensorShape{b, n, k};
  109. else
  110. B = TensorShape{b, k, n};
  111. checker.set_param(param)
  112. .set_dtype(0, dtype)
  113. .set_dtype(1, dtype)
  114. .set_dtype(2, dtype)
  115. .execs({A, B, {}});
  116. }
  117. }
  118. }
  119. }
  120. #if MEGDNN_WITH_BENCHMARK
  121. TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUICK_FP16) {
  122. int exec_times = 10;
  123. Benchmarker<MatrixMul> benchmarker_gemm(handle());
  124. benchmarker_gemm.set_times(exec_times);
  125. float mod = 1000 * exec_times / 1e9;
  126. using Param = MatrixMul::Param;
  127. auto run = [&](size_t M, size_t K, size_t N) {
  128. float time = 1.f, perf = 1.f;
  129. std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" << std::endl;
  130. Param param;
  131. param.transposeA = true;
  132. param.transposeB = true;
  133. benchmarker_gemm.set_param(param)
  134. .set_dtype(0, dtype::Float32())
  135. .set_dtype(1, dtype::Float32());
  136. time = benchmarker_gemm.exec({{M, K}, {K, N}, {}});
  137. perf = 2.f * M * K * N / time * mod;
  138. std::cout << "gemm fp32, Performance is " << perf << " Gflops" << std::endl;
  139. benchmarker_gemm.set_param(param)
  140. .set_dtype(0, dtype::Float16())
  141. .set_dtype(1, dtype::Float16());
  142. time = benchmarker_gemm.exec({{M, K}, {K, N}, {}});
  143. perf = 2.f * M * K * N / time * mod;
  144. std::cout << "gemm fp16, Performance is " << perf << " Gflops" << std::endl;
  145. };
  146. // run M = K = N
  147. run(32, 32, 32);
  148. run(64, 64, 64);
  149. run(128, 128, 128);
  150. run(256, 256, 256);
  151. run(512, 512, 512);
  152. run(1024, 1024, 1024);
  153. run(2048, 2048, 2048);
  154. }
  155. TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_ALL_SIZES_FP16) {
  156. int exec_times = 50;
  157. Benchmarker<MatrixMul> benchmarker_gemm(handle());
  158. benchmarker_gemm.set_times(exec_times);
  159. float mod = 1000 * exec_times / 1e9;
  160. using Param = MatrixMul::Param;
  161. auto run = [&](size_t M, size_t K, size_t N) {
  162. float time = 1.f, perf = 1.f;
  163. std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" << std::endl;
  164. Param param;
  165. param.transposeA = param.transposeB = true;
  166. benchmarker_gemm.set_param(param)
  167. .set_dtype(0, dtype::Float32())
  168. .set_dtype(1, dtype::Float32());
  169. time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
  170. perf = 2.f * M * K * N / time * mod;
  171. std::cout << "gemm fp32, Performance is " << perf << " Gflops" << std::endl;
  172. benchmarker_gemm.set_param(param)
  173. .set_dtype(0, dtype::Float16())
  174. .set_dtype(1, dtype::Float16());
  175. time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
  176. perf = 2.f * M * K * N / time * mod;
  177. std::cout << "gemm fp16, Performance is " << perf << " Gflops" << std::endl;
  178. };
  179. std::cout << "warm up:\n";
  180. for (int i = 0; i < 50; i++) {
  181. benchmarker_gemm.set_dtype(0, dtype::Float32())
  182. .set_dtype(1, dtype::Float32())
  183. .set_display(false)
  184. .exec({{256, 256}, {256, 256}, {}});
  185. benchmarker_gemm.set_display(true);
  186. }
  187. // run M = K = N
  188. run(8, 8, 8);
  189. run(16, 16, 16);
  190. run(32, 32, 32);
  191. run(64, 64, 64);
  192. run(128, 128, 128);
  193. run(256, 256, 256);
  194. run(512, 512, 512);
  195. run(1024, 1024, 1024);
  196. run(2048, 2048, 2048);
  197. // run sgmev like
  198. run(32, 32, 1);
  199. run(64, 64, 1);
  200. run(128, 128, 1);
  201. run(256, 256, 1);
  202. run(512, 512, 1);
  203. // run M, N >> K
  204. run(32, 16, 32);
  205. run(64, 16, 64);
  206. run(128, 16, 128);
  207. run(256, 16, 256);
  208. run(512, 16, 512);
  209. // run N, K >> M
  210. run(16, 32, 32);
  211. run(16, 64, 64);
  212. run(16, 128, 128);
  213. run(16, 256, 256);
  214. run(16, 512, 512);
  215. // run M >> K, N
  216. run(32, 16, 16);
  217. run(64, 16, 16);
  218. run(128, 16, 16);
  219. run(256, 16, 16);
  220. run(512, 16, 16);
  221. // run K >> M, N
  222. run(16, 32, 16);
  223. run(16, 64, 16);
  224. run(16, 128, 16);
  225. run(16, 256, 16);
  226. run(16, 512, 16);
  227. // run N >> M, K
  228. run(16, 16, 32);
  229. run(16, 16, 64);
  230. run(16, 16, 128);
  231. run(16, 16, 256);
  232. run(16, 16, 512);
  233. // run VGG
  234. // conv 1.1
  235. run(64, 3 * 3 * 3, 224 * 224);
  236. // conv 1.2
  237. run(128, 64 * 3 * 3, 112 * 112);
  238. // conv 2.1
  239. run(128, 128 * 3 * 3, 112 * 112);
  240. // conv 2.2
  241. run(128, 128 * 3 * 3, 56 * 56);
  242. // conv 3.1
  243. run(256, 128 * 3 * 3, 56 * 56);
  244. // conv 3.2
  245. run(256, 256 * 3 * 3, 28 * 28);
  246. // conv 4.1
  247. run(512, 256 * 3 * 3, 28 * 28);
  248. // conv 4.2
  249. run(512, 512 * 3 * 3, 14 * 14);
  250. }
  251. #endif
  252. #endif
  253. } // namespace test
  254. } // namespace megdnn
  255. // vim: syntax=cpp.doxygen