You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #include "test/common/matrix_mul.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/rng.h"
  4. #include "test/common/task_record_check.h"
  5. #include "test/fallback/fixture.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(FALLBACK, MATRIX_MUL) {
  9. Checker<MatrixMul> checker(handle());
  10. using Param = MatrixMul::Param;
  11. auto args = matrix_mul::get_matmul_args();
  12. for (auto arg : args) {
  13. auto m = arg.m, n = arg.n, k = arg.k;
  14. auto mask = arg.mask;
  15. Param param;
  16. param.transposeA = mask & 1;
  17. param.transposeB = mask & 2;
  18. TensorShape AS, BS, CS;
  19. if (param.transposeA)
  20. AS = TensorShape{k, m};
  21. else
  22. AS = TensorShape{m, k};
  23. if (param.transposeB)
  24. BS = TensorShape{n, k};
  25. else
  26. BS = TensorShape{k, n};
  27. CS = TensorShape{m, n};
  28. TensorLayout AL, BL, CL;
  29. AL = TensorLayout(AS, dtype::Float32());
  30. BL = TensorLayout(BS, dtype::Float32());
  31. CL = TensorLayout(CS, dtype::Float32());
  32. checker.set_param(param);
  33. checker.execl({AL, BL, CL});
  34. }
  35. }
  36. TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
  37. matrix_mul::check_matrix_mul(
  38. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  39. "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
  40. }
  41. TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
  42. matrix_mul::check_matrix_mul(
  43. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  44. "FB_GI_F32_4x12");
  45. }
  46. TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) {
  47. matrix_mul::check_matrix_mul(
  48. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  49. "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1);
  50. }
  51. TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
  52. TaskRecordChecker<MatrixMul> checker(1);
  53. using Param = MatrixMul::Param;
  54. auto args = matrix_mul::get_matmul_args();
  55. for (auto arg : args) {
  56. auto m = arg.m, n = arg.n, k = arg.k;
  57. auto mask = arg.mask;
  58. Param param;
  59. param.transposeA = mask & 1;
  60. param.transposeB = mask & 2;
  61. TensorShape AS, BS, CS;
  62. if (param.transposeA)
  63. AS = TensorShape{k, m};
  64. else
  65. AS = TensorShape{m, k};
  66. if (param.transposeB)
  67. BS = TensorShape{n, k};
  68. else
  69. BS = TensorShape{k, n};
  70. CS = TensorShape{m, n};
  71. TensorLayout AL, BL, CL;
  72. AL = TensorLayout(AS, dtype::Float32());
  73. BL = TensorLayout(BS, dtype::Float32());
  74. CL = TensorLayout(CS, dtype::Float32());
  75. checker.set_param(param);
  76. checker.execl({AL, BL, CL});
  77. }
  78. }
  79. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
  80. matrix_mul::check_matrix_mul(
  81. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  82. param::MatrixMul::Format::MK4, 1);
  83. }
  84. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
  85. matrix_mul::check_matrix_mul(
  86. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  87. param::MatrixMul::Format::MK8, 1);
  88. }
  89. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
  90. matrix_mul::check_matrix_mul(
  91. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  92. param::MatrixMul::Format::MK4_DOT, 1);
  93. }
  94. TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
  95. Checker<MatrixMul> checker(handle());
  96. checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
  97. using Param = MatrixMul::Param;
  98. auto args = matrix_mul::get_matmul_args();
  99. for (auto arg : args) {
  100. auto m = arg.m, n = arg.n, k = arg.k;
  101. auto mask = arg.mask;
  102. Param param;
  103. param.transposeA = mask & 1;
  104. param.transposeB = mask & 2;
  105. TensorShape AS, BS, CS;
  106. if (param.transposeA)
  107. AS = TensorShape{k, m};
  108. else
  109. AS = TensorShape{m, k};
  110. if (param.transposeB)
  111. BS = TensorShape{n, k};
  112. else
  113. BS = TensorShape{k, n};
  114. CS = TensorShape{m, n};
  115. TensorLayout AL, BL, CL;
  116. AL = TensorLayout(AS, dtype::Float32());
  117. BL = TensorLayout(BS, dtype::Float32());
  118. CL = TensorLayout(CS, dtype::Float32());
  119. checker.set_param(param);
  120. checker.execl({AL, BL, CL});
  121. }
  122. }
  123. TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
  124. Checker<BatchedMatrixMul> checker(handle());
  125. using Param = MatrixMul::Param;
  126. auto args = matrix_mul::get_batched_matmul_args();
  127. for (auto arg : args) {
  128. auto b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  129. auto mask = arg.mask;
  130. Param param;
  131. param.transposeA = mask & 1;
  132. param.transposeB = mask & 2;
  133. TensorShape AS, BS, CS;
  134. if (param.transposeA)
  135. AS = TensorShape{b, k, m};
  136. else
  137. AS = TensorShape{b, m, k};
  138. if (param.transposeB)
  139. BS = TensorShape{b, n, k};
  140. else
  141. BS = TensorShape{b, k, n};
  142. TensorLayout AL, BL;
  143. AL = TensorLayout(AS, dtype::Float32());
  144. BL = TensorLayout(BS, dtype::Float32());
  145. checker.set_param(param);
  146. checker.execs({AL, BL, {}});
  147. }
  148. }
  149. #if MEGDNN_WITH_BENCHMARK
  150. TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) {
  151. auto args = matrix_mul::get_benchmark_matmul_args();
  152. matrix_mul::benchmark_single_algo(
  153. handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  154. "FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT);
  155. }
  156. TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) {
  157. auto args = matrix_mul::get_benchmark_matmul_args();
  158. matrix_mul::benchmark_single_algo(
  159. handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  160. "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4);
  161. }
  162. TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
  163. auto args = matrix_mul::get_benchmark_matmul_args();
  164. matrix_mul::benchmark_single_algo(
  165. handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
  166. "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4);
  167. }
  168. #endif
  169. } // namespace test
  170. } // namespace megdnn
  171. // vim: syntax=cpp.doxygen