You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #include "test/common/matrix_mul.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/rng.h"
  4. #include "test/common/task_record_check.h"
  5. #include "test/fallback/fixture.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(FALLBACK, MATRIX_MUL) {
  9. Checker<MatrixMul> checker(handle());
  10. using Param = MatrixMul::Param;
  11. auto args = matrix_mul::get_matmul_args();
  12. for (auto arg : args) {
  13. auto m = arg.m, n = arg.n, k = arg.k;
  14. auto mask = arg.mask;
  15. Param param;
  16. param.transposeA = mask & 1;
  17. param.transposeB = mask & 2;
  18. TensorShape AS, BS, CS;
  19. if (param.transposeA)
  20. AS = TensorShape{k, m};
  21. else
  22. AS = TensorShape{m, k};
  23. if (param.transposeB)
  24. BS = TensorShape{n, k};
  25. else
  26. BS = TensorShape{k, n};
  27. CS = TensorShape{m, n};
  28. TensorLayout AL, BL, CL;
  29. AL = TensorLayout(AS, dtype::Float32());
  30. BL = TensorLayout(BS, dtype::Float32());
  31. CL = TensorLayout(CS, dtype::Float32());
  32. checker.set_param(param);
  33. checker.execl({AL, BL, CL});
  34. }
  35. }
  36. TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
  37. matrix_mul::check_matrix_mul(
  38. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  39. "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
  40. }
  41. TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) {
  42. matrix_mul::check_matrix_mul(
  43. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  44. "FB_GI_F32_4x12");
  45. }
  46. TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
  47. TaskRecordChecker<MatrixMul> checker(1);
  48. using Param = MatrixMul::Param;
  49. auto args = matrix_mul::get_matmul_args();
  50. for (auto arg : args) {
  51. auto m = arg.m, n = arg.n, k = arg.k;
  52. auto mask = arg.mask;
  53. Param param;
  54. param.transposeA = mask & 1;
  55. param.transposeB = mask & 2;
  56. TensorShape AS, BS, CS;
  57. if (param.transposeA)
  58. AS = TensorShape{k, m};
  59. else
  60. AS = TensorShape{m, k};
  61. if (param.transposeB)
  62. BS = TensorShape{n, k};
  63. else
  64. BS = TensorShape{k, n};
  65. CS = TensorShape{m, n};
  66. TensorLayout AL, BL, CL;
  67. AL = TensorLayout(AS, dtype::Float32());
  68. BL = TensorLayout(BS, dtype::Float32());
  69. CL = TensorLayout(CS, dtype::Float32());
  70. checker.set_param(param);
  71. checker.execl({AL, BL, CL});
  72. }
  73. }
  74. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
  75. matrix_mul::check_matrix_mul(
  76. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  77. param::MatrixMul::Format::MK4, 1);
  78. }
  79. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
  80. matrix_mul::check_matrix_mul(
  81. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  82. param::MatrixMul::Format::MK8, 1);
  83. }
  84. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
  85. matrix_mul::check_matrix_mul(
  86. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  87. param::MatrixMul::Format::MK4_DOT, 1);
  88. }
  89. TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
  90. Checker<MatrixMul> checker(handle());
  91. checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
  92. using Param = MatrixMul::Param;
  93. auto args = matrix_mul::get_matmul_args();
  94. for (auto arg : args) {
  95. auto m = arg.m, n = arg.n, k = arg.k;
  96. auto mask = arg.mask;
  97. Param param;
  98. param.transposeA = mask & 1;
  99. param.transposeB = mask & 2;
  100. TensorShape AS, BS, CS;
  101. if (param.transposeA)
  102. AS = TensorShape{k, m};
  103. else
  104. AS = TensorShape{m, k};
  105. if (param.transposeB)
  106. BS = TensorShape{n, k};
  107. else
  108. BS = TensorShape{k, n};
  109. CS = TensorShape{m, n};
  110. TensorLayout AL, BL, CL;
  111. AL = TensorLayout(AS, dtype::Float32());
  112. BL = TensorLayout(BS, dtype::Float32());
  113. CL = TensorLayout(CS, dtype::Float32());
  114. checker.set_param(param);
  115. checker.execl({AL, BL, CL});
  116. }
  117. }
  118. TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
  119. Checker<BatchedMatrixMul> checker(handle());
  120. using Param = MatrixMul::Param;
  121. auto args = matrix_mul::get_batched_matmul_args();
  122. for (auto arg : args) {
  123. auto b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  124. auto mask = arg.mask;
  125. Param param;
  126. param.transposeA = mask & 1;
  127. param.transposeB = mask & 2;
  128. TensorShape AS, BS, CS;
  129. if (param.transposeA)
  130. AS = TensorShape{b, k, m};
  131. else
  132. AS = TensorShape{b, m, k};
  133. if (param.transposeB)
  134. BS = TensorShape{b, n, k};
  135. else
  136. BS = TensorShape{b, k, n};
  137. TensorLayout AL, BL;
  138. AL = TensorLayout(AS, dtype::Float32());
  139. BL = TensorLayout(BS, dtype::Float32());
  140. checker.set_param(param);
  141. checker.execs({AL, BL, {}});
  142. }
  143. }
  144. } // namespace test
  145. } // namespace megdnn
  146. // vim: syntax=cpp.doxygen