You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #include "megdnn/oprs.h"
  2. #include "src/common/utils.h"
  3. namespace megdnn {
  4. void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
  5. // Expect that the user specifies output dtype (C), we then do sanity
  6. // check on the dtype supplied by the user. C_dtype and C_dtype2 are the
  7. // expected dtypes. If the user does not specify an output dtype by setting
  8. // C = {}, we deduce one (C_dtype) and return it to the user.
  9. DType C_candi, C_candi2;
  10. if (A.category() == DTypeCategory::FLOAT) {
  11. C_candi = A;
  12. } else if (A.enumv() == DTypeEnum::Int8) {
  13. C_candi = dtype::Int32();
  14. C_candi2 = dtype::Int16();
  15. } else if (A.enumv() == DTypeEnum::Int16) {
  16. C_candi = dtype::Int32();
  17. } else if (A.enumv() == DTypeEnum::QuantizedS8) {
  18. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  19. } else if (A.enumv() == DTypeEnum::Quantized8Asymm) {
  20. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  21. } else if (A.enumv() == DTypeEnum::Quantized4Asymm) {
  22. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  23. } else if (A.enumv() == DTypeEnum::QuantizedS4) {
  24. C_candi = dtype::QuantizedS16(mul_scale(A, B));
  25. }
  26. if (!C.valid()) {
  27. C = C_candi;
  28. }
  29. megdnn_assert(
  30. C.valid() && (C == C_candi || C == C_candi2),
  31. "unsupported MatMul(%s, %s) -> %s", A.name(), B.name(), C.name());
  32. }
  33. void MatrixMulForward::deduce_layout(
  34. const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
  35. megdnn_assert(
  36. A.dtype.enumv() == B.dtype.enumv(),
  37. "matmul input should be of same dtype, got %s and %s", A.dtype.name(),
  38. B.dtype.name());
  39. deduce_dtype(A.dtype, B.dtype, C.dtype);
  40. size_t A0, A1, B0, B1;
  41. if (param().format == param::MatrixMul::Format::DEFAULT) {
  42. megdnn_assert(
  43. A.ndim == 2 && B.ndim == 2,
  44. "matmul requires input to be 2-dimensional; get: %s %s",
  45. A.TensorShape::to_string().c_str(), B.TensorShape::to_string().c_str());
  46. A0 = A.shape[0];
  47. A1 = A.shape[1];
  48. B0 = B.shape[0];
  49. B1 = B.shape[1];
  50. if (m_param.transposeA)
  51. std::swap(A0, A1);
  52. if (m_param.transposeB)
  53. std::swap(B0, B1);
  54. megdnn_assert(
  55. A1 == B0,
  56. "shape mismatch in matmal: (transposed) A is (%zu,%zu), "
  57. "(transposed) B is (%zu,%zu)",
  58. A0, A1, B0, B1);
  59. C = TensorLayout(TensorShape({A0, B1}), C.dtype);
  60. } else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
  61. A0 = A.shape[0];
  62. A1 = A.shape[1];
  63. B0 = B.shape[0];
  64. B1 = B.shape[1];
  65. megdnn_assert(!m_param.transposeA && !m_param.transposeB);
  66. megdnn_assert(A0 == 1 && A1 % 4 == 0);
  67. megdnn_assert(B.ndim == 4);
  68. C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype);
  69. } else {
  70. auto do_deduce = [&](size_t pack_size) {
  71. megdnn_assert(
  72. A.ndim == 4 && B.ndim == 3,
  73. "matmul requires input dimension to be A(4), B(3); "
  74. "get: %s %s",
  75. A.TensorShape::to_string().c_str(),
  76. B.TensorShape::to_string().c_str());
  77. A0 = A.shape[0];
  78. A1 = A.shape[1];
  79. B0 = B.shape[0];
  80. B1 = B.shape[1];
  81. if (m_param.transposeA)
  82. std::swap(A0, A1);
  83. if (m_param.transposeB)
  84. std::swap(B0, B1);
  85. megdnn_assert(
  86. A1 == B0,
  87. "shape mismatch in matmal: (transposed) A is "
  88. "(%zu,%zu,4,4), "
  89. "(transposed) B is (%zu,%zu,4)",
  90. A0, A1, B0, B1);
  91. C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
  92. };
  93. do_deduce(pack_size(param().format));
  94. }
  95. }
// Validates operand layouts, dtypes and the workspace size immediately
// before execution; any violation aborts via megdnn_assert.
void MatrixMulForward::check_exec(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_in_bytes) {
    // Lazily builds a human-readable summary of the operands and transpose
    // flags for the shape-mismatch assertions below.
    auto errmsg = [&]() {
        std::string msg;
        msg.append("A=");
        msg.append(A.to_string());
        msg.append(", B=");
        msg.append(B.to_string());
        msg.append(", C=");
        msg.append(C.to_string());
        msg.append(", transposeA=");
        msg.append(std::to_string(param().transposeA));
        msg.append(", transposeB=");
        msg.append(std::to_string(param().transposeB));
        return msg;
    };
    // Silences unused-variable warnings in builds where asserts compile out.
    MEGDNN_MARK_USED_VAR(errmsg);
    if (param().format == param::MatrixMul::Format::DEFAULT) {
        // Plain 2-D matmul: all operands must be rank-2 with a contiguous
        // innermost dim; the row stride may be larger than the row length
        // (sub-matrix views are allowed).
        megdnn_assert_eq_size_t(A.ndim, 2_z);
        megdnn_assert_eq_size_t(B.ndim, 2_z);
        megdnn_assert_eq_size_t(C.ndim, 2_z);
        megdnn_assert(A.stride[1] == 1);
        megdnn_assert(A.stride[0] >= static_cast<ptrdiff_t>(A.shape[1]));
        megdnn_assert(B.stride[1] == 1);
        megdnn_assert(B.stride[0] >= static_cast<ptrdiff_t>(B.shape[1]));
        megdnn_assert(C.stride[1] == 1);
        megdnn_assert(C.stride[0] >= static_cast<ptrdiff_t>(C.shape[1]));
        size_t A0, A1, B0, B1, C0, C1;
        A0 = A.shape[0];
        A1 = A.shape[1];
        B0 = B.shape[0];
        B1 = B.shape[1];
        C0 = C.shape[0];
        C1 = C.shape[1];
        // Apply the transpose flags so the checks below see effective shapes.
        if (m_param.transposeA)
            std::swap(A0, A1);
        if (m_param.transposeB)
            std::swap(B0, B1);
        megdnn_assert(A0 == C0, "%s", errmsg().c_str());
        megdnn_assert(B1 == C1, "%s", errmsg().c_str());
        megdnn_assert(A1 == B0, "%s", errmsg().c_str());
    } else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
        // N32K4_DOT: A must be (1, K) with K divisible by 4; B must be
        // rank-4 with trailing dims (32, 4); transpose is unsupported.
        size_t A0 = A.shape[0];
        size_t A1 = A.shape[1];
        size_t B2 = B.shape[2];
        size_t B3 = B.shape[3];
        megdnn_assert(!m_param.transposeA && !m_param.transposeB);
        megdnn_assert(A0 == 1 && A1 % 4 == 0);
        megdnn_assert(B.ndim == 4);
        megdnn_assert(B2 == 32 && B3 == 4);
        megdnn_assert_contiguous(A);
        megdnn_assert_contiguous(B);
        megdnn_assert_contiguous(C);
    } else {
        // Blocked formats (MK4 / MK4_DOT / MK8): A is rank-4, B rank-3,
        // C rank-3, all fully contiguous; only the two leading dims take
        // part in the shape-compatibility checks.
        megdnn_assert_eq_size_t(A.ndim, 4_z);
        megdnn_assert_eq_size_t(B.ndim, 3_z);
        megdnn_assert_eq_size_t(C.ndim, 3_z);
        megdnn_assert_contiguous(A);
        megdnn_assert_contiguous(B);
        megdnn_assert_contiguous(C);
        size_t A0, A1, B0, B1, C0, C1;
        A0 = A.shape[0];
        A1 = A.shape[1];
        B0 = B.shape[0];
        B1 = B.shape[1];
        C0 = C.shape[0];
        C1 = C.shape[1];
        if (m_param.transposeA)
            std::swap(A0, A1);
        if (m_param.transposeB)
            std::swap(B0, B1);
        megdnn_assert(A0 == C0, "%s", errmsg().c_str());
        megdnn_assert(B1 == C1, "%s", errmsg().c_str());
        megdnn_assert(A1 == B0, "%s", errmsg().c_str());
    }
    // Input dtypes must match; the accepted output dtypes mirror
    // deduce_dtype(). NOTE(review): an Int16 input falls through every
    // branch here unchecked — presumably intentional leniency; verify.
    megdnn_assert(A.dtype.enumv() == B.dtype.enumv());
    if (A.dtype.category() == DTypeCategory::FLOAT) {
        megdnn_assert(A.dtype == C.dtype);
    } else if (A.dtype == dtype::Int8()) {
        megdnn_assert(C.dtype == dtype::Int16() || C.dtype == dtype::Int32());
    } else if (
            A.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            A.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
            A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32);
    } else if (A.dtype.enumv() == DTypeEnum::QuantizedS4) {
        megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16);
    }
    // FLOAT32 compute mode (fp32 accumulation) only makes sense for
    // half-precision inputs.
    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32 DNN_INC_FLOAT16(
                    || A.dtype == dtype::Float16() ||
                    A.dtype == dtype::BFloat16()),
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
    // The caller-provided workspace must cover what this layout needs.
    auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
  194. size_t MatrixMulForward::pack_size(const Param::Format format) {
  195. switch (format) {
  196. case Param::Format::DEFAULT:
  197. return 1;
  198. case Param::Format::MK4:
  199. return 4;
  200. case Param::Format::MK4_DOT:
  201. return 4;
  202. case Param::Format::MK8:
  203. return 8;
  204. default:
  205. megdnn_throw("Unknown matmul format.");
  206. }
  207. }
  208. } // namespace megdnn
  209. // vim: syntax=cpp.doxygen