You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * \file dnn/test/fallback/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/matrix_mul.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/rng.h"
  14. #include "test/common/task_record_check.h"
  15. #include "test/fallback/fixture.h"
  16. namespace megdnn {
  17. namespace test {
  18. TEST_F(FALLBACK, MATRIX_MUL) {
  19. Checker<MatrixMul> checker(handle());
  20. using Param = MatrixMul::Param;
  21. auto args = matrix_mul::get_matmul_args();
  22. for (auto arg : args) {
  23. auto m = arg.m, n = arg.n, k = arg.k;
  24. auto mask = arg.mask;
  25. Param param;
  26. param.transposeA = mask & 1;
  27. param.transposeB = mask & 2;
  28. TensorShape AS, BS, CS;
  29. if (param.transposeA)
  30. AS = TensorShape{k, m};
  31. else
  32. AS = TensorShape{m, k};
  33. if (param.transposeB)
  34. BS = TensorShape{n, k};
  35. else
  36. BS = TensorShape{k, n};
  37. CS = TensorShape{m, n};
  38. TensorLayout AL, BL, CL;
  39. AL = TensorLayout(AS, dtype::Float32());
  40. BL = TensorLayout(BS, dtype::Float32());
  41. CL = TensorLayout(CS, dtype::Float32());
  42. checker.set_param(param);
  43. checker.execl({AL, BL, CL});
  44. }
  45. }
  46. TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
  47. TaskRecordChecker<MatrixMul> checker(1);
  48. using Param = MatrixMul::Param;
  49. auto args = matrix_mul::get_matmul_args();
  50. for (auto arg : args) {
  51. auto m = arg.m, n = arg.n, k = arg.k;
  52. auto mask = arg.mask;
  53. Param param;
  54. param.transposeA = mask & 1;
  55. param.transposeB = mask & 2;
  56. TensorShape AS, BS, CS;
  57. if (param.transposeA)
  58. AS = TensorShape{k, m};
  59. else
  60. AS = TensorShape{m, k};
  61. if (param.transposeB)
  62. BS = TensorShape{n, k};
  63. else
  64. BS = TensorShape{k, n};
  65. CS = TensorShape{m, n};
  66. TensorLayout AL, BL, CL;
  67. AL = TensorLayout(AS, dtype::Float32());
  68. BL = TensorLayout(BS, dtype::Float32());
  69. CL = TensorLayout(CS, dtype::Float32());
  70. checker.set_param(param);
  71. checker.execl({AL, BL, CL});
  72. }
  73. }
  74. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
  75. matrix_mul::check_matrix_mul(
  76. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  77. param::MatrixMul::Format::MK4, 1);
  78. }
  79. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
  80. matrix_mul::check_matrix_mul(
  81. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  82. param::MatrixMul::Format::MK8, 1);
  83. }
  84. TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
  85. matrix_mul::check_matrix_mul(
  86. dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_NAIVE",
  87. param::MatrixMul::Format::MK4_DOT, 1);
  88. }
  89. TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
  90. Checker<MatrixMul> checker(handle());
  91. checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
  92. using Param = MatrixMul::Param;
  93. auto args = matrix_mul::get_matmul_args();
  94. for (auto arg : args) {
  95. auto m = arg.m, n = arg.n, k = arg.k;
  96. auto mask = arg.mask;
  97. Param param;
  98. param.transposeA = mask & 1;
  99. param.transposeB = mask & 2;
  100. TensorShape AS, BS, CS;
  101. if (param.transposeA)
  102. AS = TensorShape{k, m};
  103. else
  104. AS = TensorShape{m, k};
  105. if (param.transposeB)
  106. BS = TensorShape{n, k};
  107. else
  108. BS = TensorShape{k, n};
  109. CS = TensorShape{m, n};
  110. TensorLayout AL, BL, CL;
  111. AL = TensorLayout(AS, dtype::Float32());
  112. BL = TensorLayout(BS, dtype::Float32());
  113. CL = TensorLayout(CS, dtype::Float32());
  114. checker.set_param(param);
  115. checker.execl({AL, BL, CL});
  116. }
  117. }
  118. TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
  119. Checker<BatchedMatrixMul> checker(handle());
  120. using Param = MatrixMul::Param;
  121. auto args = matrix_mul::get_batched_matmul_args();
  122. for (auto arg : args) {
  123. auto b = arg.b, m = arg.m, n = arg.n, k = arg.k;
  124. auto mask = arg.mask;
  125. Param param;
  126. param.transposeA = mask & 1;
  127. param.transposeB = mask & 2;
  128. TensorShape AS, BS, CS;
  129. if (param.transposeA)
  130. AS = TensorShape{b, k, m};
  131. else
  132. AS = TensorShape{b, m, k};
  133. if (param.transposeB)
  134. BS = TensorShape{b, n, k};
  135. else
  136. BS = TensorShape{b, k, n};
  137. TensorLayout AL, BL;
  138. AL = TensorLayout(AS, dtype::Float32());
  139. BL = TensorLayout(BS, dtype::Float32());
  140. checker.set_param(param);
  141. checker.execs({AL, BL, {}});
  142. }
  143. }
  144. } // namespace test
  145. } // namespace megdnn
  146. // vim: syntax=cpp.doxygen