You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_mul.cpp 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. /**
  2. * \file dnn/test/fallback/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/matrix_mul.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/rng.h"
  14. #include "test/common/task_record_check.h"
  15. #include "test/fallback/fixture.h"
  16. namespace megdnn {
  17. namespace test {
  18. TEST_F(FALLBACK, MATRIX_MUL) {
  19. Checker<MatrixMul> checker(handle());
  20. using Param = MatrixMul::Param;
  21. auto args = matrix_mul::get_matmul_args();
  22. for (auto arg : args) {
  23. auto m = arg.m, n = arg.n, k = arg.k;
  24. auto mask = arg.mask;
  25. Param param;
  26. param.transposeA = mask & 1;
  27. param.transposeB = mask & 2;
  28. TensorShape AS, BS, CS;
  29. if (param.transposeA)
  30. AS = TensorShape{k, m};
  31. else
  32. AS = TensorShape{m, k};
  33. if (param.transposeB)
  34. BS = TensorShape{n, k};
  35. else
  36. BS = TensorShape{k, n};
  37. CS = TensorShape{m, n};
  38. TensorLayout AL, BL, CL;
  39. AL = TensorLayout(AS, dtype::Float32());
  40. BL = TensorLayout(BS, dtype::Float32());
  41. CL = TensorLayout(CS, dtype::Float32());
  42. checker.set_param(param);
  43. checker.execl({AL, BL, CL});
  44. }
  45. }
  46. TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
  47. matrix_mul::check_matrix_mul(
  48. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
  49. "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
  50. }
  51. TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
  52. TaskRecordChecker<MatrixMul> checker(1);
  53. using Param = MatrixMul::Param;
  54. auto args = matrix_mul::get_matmul_args();
  55. for (auto arg : args) {
  56. auto m = arg.m, n = arg.n, k = arg.k;
  57. auto mask = arg.mask;
  58. Param param;
  59. param.transposeA = mask & 1;
  60. param.transposeB = mask & 2;
  61. TensorShape AS, BS, CS;
  62. if (param.transposeA)
  63. AS = TensorShape{k, m};
  64. else
  65. AS = TensorShape{m, k};
  66. if (param.transposeB)
  67. BS = TensorShape{n, k};
  68. else
  69. BS = TensorShape{k, n};
  70. CS = TensorShape{m, n};
  71. TensorLayout AL, BL, CL;
  72. AL = TensorLayout(AS, dtype::Float32());
  73. BL = TensorLayout(BS, dtype::Float32());
  74. CL = TensorLayout(CS, dtype::Float32());
  75. checker.set_param(param);
  76. checker.execl({AL, BL, CL});
  77. }
  78. }
  79. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
  80. matrix_mul::check_matrix_mul(
  81. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  82. param::MatrixMul::Format::MK4, 1);
  83. }
  84. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
  85. matrix_mul::check_matrix_mul(
  86. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  87. param::MatrixMul::Format::MK8, 1);
  88. }
  89. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
  90. matrix_mul::check_matrix_mul(
  91. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  92. param::MatrixMul::Format::MK4_DOT, 1);
  93. }
  94. TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
  95. Checker<MatrixMul> checker(handle());
  96. checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
  97. using Param = MatrixMul::Param;
  98. auto args = matrix_mul::get_matmul_args();
  99. for (auto arg : args) {
  100. auto m = arg.m, n = arg.n, k = arg.k;
  101. auto mask = arg.mask;
  102. Param param;
  103. param.transposeA = mask & 1;
  104. param.transposeB = mask & 2;
  105. TensorShape AS, BS, CS;
  106. if (param.transposeA)
  107. AS = TensorShape{k, m};
  108. else
  109. AS = TensorShape{m, k};
  110. if (param.transposeB)
  111. BS = TensorShape{n, k};
  112. else
  113. BS = TensorShape{k, n};
  114. CS = TensorShape{m, n};
  115. TensorLayout AL, BL, CL;
  116. AL = TensorLayout(AS, dtype::Float32());
  117. BL = TensorLayout(BS, dtype::Float32());
  118. CL = TensorLayout(CS, dtype::Float32());
  119. checker.set_param(param);
  120. checker.execl({AL, BL, CL});
  121. }
  122. }
  123. TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
  124. Checker<BatchedMatrixMul> checker(handle());
  125. using Param = MatrixMul::Param;
  126. auto args = matrix_mul::get_batched_matmul_args();
  127. for (auto arg : args) {
  128. auto b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  129. auto mask = arg.mask;
  130. Param param;
  131. param.transposeA = mask & 1;
  132. param.transposeB = mask & 2;
  133. TensorShape AS, BS, CS;
  134. if (param.transposeA)
  135. AS = TensorShape{b, k, m};
  136. else
  137. AS = TensorShape{b, m, k};
  138. if (param.transposeB)
  139. BS = TensorShape{b, n, k};
  140. else
  141. BS = TensorShape{b, k, n};
  142. TensorLayout AL, BL;
  143. AL = TensorLayout(AS, dtype::Float32());
  144. BL = TensorLayout(BS, dtype::Float32());
  145. checker.set_param(param);
  146. checker.execs({AL, BL, {}});
  147. }
  148. }
  149. } // namespace test
  150. } // namespace megdnn
  151. // vim: syntax=cpp.doxygen