
batched_matrix_mul.cpp 8.6 kB
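// Correctness checks for BatchedMatrixMul (fp32, fp16, and task-record mode)
// plus transposed-GEMM benchmarks for the AArch64 backend.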

#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"
#include "test/aarch64/fixture.h"

namespace megdnn {
namespace test {

TEST_F(AARCH64, BATCHED_MATRIX_MUL) {
    Checker<BatchedMatrixMul> checker(handle());
    checker.set_epsilon(1e-2);
    using Param = MatrixMul::Param;
    // auto args = get_batch_matmul_args();
    auto args = matrix_mul::get_batched_matmul_args();
    for (DType dtype : std::vector<DType>{dtype::Float32()}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            for (auto& arg : args) {
                size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
                //! testing every batch size would make this case time out
                if (b != 2) {
                    continue;
                }
                Param param;
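                // bit 0 of mask selects transposeA and bit 1 selects
                // transposeB, so mask = 0..3 covers all four layouts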
                param.transposeA = mask & 1;
                param.transposeB = mask & 2;
                TensorShape A, B;
                if (param.transposeA)
                    A = TensorShape{b, k, m};
                else
                    A = TensorShape{b, m, k};
                if (param.transposeB)
                    B = TensorShape{b, n, k};
                else
                    B = TensorShape{b, k, n};
                checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).execs(
                        {A, B, {}});
            }
        }
    }
}

TEST_F(AARCH64, BATCHED_MATRIX_MUL_RECORD) {
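    // same coverage as above, but driven through TaskRecordChecker to
    // exercise the operator via the recorded-task execution path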
    TaskRecordChecker<BatchedMatrixMul> checker(0);
    checker.set_epsilon(1e-2);
    using Param = MatrixMul::Param;
    // auto args = get_batch_matmul_args();
    auto args = matrix_mul::get_batched_matmul_args();
    for (DType dtype : std::vector<DType>{dtype::Float32()}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            for (auto& arg : args) {
                size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
                //! testing every batch size would make this case time out
                if (b != 2) {
                    continue;
                }
                Param param;
                param.transposeA = mask & 1;
                param.transposeB = mask & 2;
                TensorShape A, B;
                if (param.transposeA)
                    A = TensorShape{b, k, m};
                else
                    A = TensorShape{b, m, k};
                if (param.transposeB)
                    B = TensorShape{b, n, k};
                else
                    B = TensorShape{b, k, n};
                checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).execs(
                        {A, B, {}});
            }
        }
    }
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
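// the fp16 tests require the ARMv8.2-A half-precision vector arithmetic
// extension, signalled by this ACLE feature macro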
TEST_F(AARCH64, BATCHED_MATRIX_MUL_FP16) {
    Checker<BatchedMatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    auto args = matrix_mul::get_batched_matmul_args();
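    // fp16 accumulation is less precise, so draw inputs from a normal
    // distribution and loosen the comparison epsilon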
    NormalRNG rng(1.f);
    checker.set_rng(0, &rng).set_rng(1, &rng).set_epsilon(1e-2);
    for (DType dtype : std::vector<DType>{dtype::Float16()}) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            for (auto& arg : args) {
                size_t b = arg.b, m = arg.m, n = arg.n, k = arg.k;
                //! testing every batch size would make this case time out on
                //! sdm855
                if (b != 1) {
                    continue;
                }
                Param param;
                param.transposeA = mask & 1;
                param.transposeB = mask & 2;
                TensorShape A, B;
                if (param.transposeA)
                    A = TensorShape{b, k, m};
                else
                    A = TensorShape{b, m, k};
                if (param.transposeB)
                    B = TensorShape{b, n, k};
                else
                    B = TensorShape{b, k, n};
                checker.set_param(param)
                        .set_dtype(0, dtype)
                        .set_dtype(1, dtype)
                        .set_dtype(2, dtype)
                        .execs({A, B, {}});
            }
        }
    }
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUICK_FP16) {
    int exec_times = 10;
    Benchmarker<MatrixMul> benchmarker_gemm(handle());
    benchmarker_gemm.set_times(exec_times);
    float mod = 1000 * exec_times / 1e9;
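    // exec() reports the total wall time in ms across exec_times runs
    // (consistent with mod = 1000 * exec_times / 1e9); a GEMM performs
    // 2*M*K*N flops, so 2*M*K*N / time * mod yields Gflop/s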
    using Param = MatrixMul::Param;
    auto run = [&](size_t M, size_t K, size_t N) {
        float time = 1.f, perf = 1.f;
        std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" << std::endl;
        Param param;
        param.transposeA = true;
        param.transposeB = true;
        benchmarker_gemm.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32());
        // with both operands transposed, A is stored as {K, M} and B as {N, K}
        time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp32, Performance is " << perf << " Gflops" << std::endl;
        benchmarker_gemm.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16());
        time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp16, Performance is " << perf << " Gflops" << std::endl;
    };
    // run M = K = N
    run(32, 32, 32);
    run(64, 64, 64);
    run(128, 128, 128);
    run(256, 256, 256);
    run(512, 512, 512);
    run(1024, 1024, 1024);
    run(2048, 2048, 2048);
}

TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_ALL_SIZES_FP16) {
    int exec_times = 50;
    Benchmarker<MatrixMul> benchmarker_gemm(handle());
    benchmarker_gemm.set_times(exec_times);
    float mod = 1000 * exec_times / 1e9;
    using Param = MatrixMul::Param;
    auto run = [&](size_t M, size_t K, size_t N) {
        float time = 1.f, perf = 1.f;
        std::cout << "GEMM: (" << M << ", " << K << ", " << N << ")" << std::endl;
        Param param;
        param.transposeA = param.transposeB = true;
        benchmarker_gemm.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32());
        time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp32, Performance is " << perf << " Gflops" << std::endl;
        benchmarker_gemm.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16());
        time = benchmarker_gemm.exec({{K, M}, {N, K}, {}});
        perf = 2.f * M * K * N / time * mod;
        std::cout << "gemm fp16, Performance is " << perf << " Gflops" << std::endl;
    };
    std::cout << "warm up:\n";
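    // run a mid-sized fp32 multiply repeatedly first so that CPU frequency
    // scaling settles before the measured runs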
    for (int i = 0; i < 50; i++) {
        benchmarker_gemm.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_display(false)
                .exec({{256, 256}, {256, 256}, {}});
        benchmarker_gemm.set_display(true);
    }
    // run M = K = N
    run(8, 8, 8);
    run(16, 16, 16);
    run(32, 32, 32);
    run(64, 64, 64);
    run(128, 128, 128);
    run(256, 256, 256);
    run(512, 512, 512);
    run(1024, 1024, 1024);
    run(2048, 2048, 2048);
    // run sgemv-like shapes (N = 1)
    run(32, 32, 1);
    run(64, 64, 1);
    run(128, 128, 1);
    run(256, 256, 1);
    run(512, 512, 1);
    // run M, N >> K
    run(32, 16, 32);
    run(64, 16, 64);
    run(128, 16, 128);
    run(256, 16, 256);
    run(512, 16, 512);
    // run N, K >> M
    run(16, 32, 32);
    run(16, 64, 64);
    run(16, 128, 128);
    run(16, 256, 256);
    run(16, 512, 512);
    // run M >> K, N
    run(32, 16, 16);
    run(64, 16, 16);
    run(128, 16, 16);
    run(256, 16, 16);
    run(512, 16, 16);
    // run K >> M, N
    run(16, 32, 16);
    run(16, 64, 16);
    run(16, 128, 16);
    run(16, 256, 16);
    run(16, 512, 16);
    // run N >> M, K
    run(16, 16, 32);
    run(16, 16, 64);
    run(16, 16, 128);
    run(16, 16, 256);
    run(16, 16, 512);
    // run VGG
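    // each conv layer lowered via im2col: M = output channels,
    // K = input channels * kernel area, N = output spatial size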
    // conv 1.1
    run(64, 3 * 3 * 3, 224 * 224);
    // conv 1.2
    run(128, 64 * 3 * 3, 112 * 112);
    // conv 2.1
    run(128, 128 * 3 * 3, 112 * 112);
    // conv 2.2
    run(128, 128 * 3 * 3, 56 * 56);
    // conv 3.1
    run(256, 128 * 3 * 3, 56 * 56);
    // conv 3.2
    run(256, 256 * 3 * 3, 28 * 28);
    // conv 4.1
    run(512, 256 * 3 * 3, 28 * 28);
    // conv 4.2
    run(512, 512 * 3 * 3, 14 * 14);
}
#endif
#endif

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen