
batched_matrix_mul.cpp

#include "megdnn/oprs.h"

#include "src/common/utils.h"

namespace megdnn {

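/*
 * Deduce the output dtype C from the input dtypes A and B: float inputs
 * keep their dtype, Int8 accumulates into Int32 (Int16 is also accepted),
 * and the quantized 8-bit/4-bit types accumulate into QuantizedS32 whose
 * scale is scale(A) * scale(B). A caller-provided C is validated against
 * these candidates.
 */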
void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
    DType C_candi, C_candi2;
    if (A.category() == DTypeCategory::FLOAT) {
        C_candi = A;
    } else if (A.enumv() == DTypeEnum::Int8) {
        C_candi = dtype::Int32();
        C_candi2 = dtype::Int16();
    } else if (A.enumv() == DTypeEnum::QuantizedS8) {
        C_candi = dtype::QuantizedS32(mul_scale(A, B));
    } else if (A.enumv() == DTypeEnum::Quantized8Asymm) {
        C_candi = dtype::QuantizedS32(mul_scale(A, B));
    } else if (A.enumv() == DTypeEnum::Quantized4Asymm) {
        C_candi = dtype::QuantizedS32(mul_scale(A, B));
    }
    if (!C.valid()) {
        C = C_candi;
    }
    megdnn_assert(
            C.valid() && (C == C_candi || C == C_candi2),
  23. "runtime does not support BatchedMatMul(%s, %s) -> %s\n"
  24. "now support case list: BatchedMatMul(FLOAT, FLOAT)\n"
  25. " BatchedMatMul(Int8, Int8)\n"
  26. " BatchedMatMul(QuantizedS8, QuantizedS8)\n"
  27. " BatchedMatMul(Quantized8Asymm, Quantized8Asymm)\n"
  28. " BatchedMatMul(Quantized4Asymm, Quantized4Asymm)\n",
  29. A.name(), B.name(), C.name());
  30. }
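/*
 * Deduce the output layout C of a batched matrix multiply. Both inputs must
 * be 3-dimensional (batch, rows, cols) with a contiguous innermost axis and
 * matching batch size and dtype; transposeA/transposeB swap the row and
 * column dimensions before the shape check, so the result has shape
 * (batch, A_rows, B_cols).
 */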
void BatchedMatrixMulForward::deduce_layout(
        const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
    auto errmsg = [&]() {
        std::string msg;
        msg.append("A=");
        msg.append(A.to_string());
        msg.append(", B=");
        msg.append(B.to_string());
        msg.append(", C=");
        msg.append(C.to_string());
        msg.append(", transposeA=");
        msg.append(std::to_string(m_param.transposeA));
        msg.append(", transposeB=");
        msg.append(std::to_string(m_param.transposeB));
        return msg;
    };
    MEGDNN_MARK_USED_VAR(errmsg);
    auto good_layout = [](const TensorLayout& l) {
        // l.stride[0] == 0 is accepted because im2col convolution uses
        // batched matrix multiplication with a broadcasted filter tensor;
        // this broadcast path is only implemented in the OpenCL backend.
        return l.ndim == 3 && l.stride[2] == 1 &&
               l.stride[1] >= static_cast<ptrdiff_t>(l.shape[2]) &&
               (l.shape[0] == 1 ||
                l.stride[0] >= static_cast<ptrdiff_t>(l.shape[1]) * l.stride[1] ||
                l.stride[0] == 0);
    };
    size_t A0, A1, B0, B1;
    A0 = A.shape[1];
    A1 = A.shape[2];
    B0 = B.shape[1];
    B1 = B.shape[2];
    if (m_param.transposeA)
        std::swap(A0, A1);
    if (m_param.transposeB)
        std::swap(B0, B1);
    deduce_dtype(A.dtype, B.dtype, C.dtype);
    megdnn_assert(
            good_layout(A) && good_layout(B) && A1 == B0 && A[0] == B[0] &&
                    A.dtype.enumv() == B.dtype.enumv(),
            "bad input layouts: %s", errmsg().c_str());
    C = TensorLayout(TensorShape({A[0], A0, B1}), C.dtype);
}

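/*
 * Validate the arguments passed to exec(): C must match the layout deduced
 * from A and B, and the provided workspace must be at least as large as
 * get_workspace_in_bytes() reports.
 */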
void BatchedMatrixMulForward::check_exec(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_in_bytes) {
    TensorLayout C_expect;
    deduce_layout(A, B, C_expect);
    megdnn_assert(
            C_expect.eq_layout(C), "bad layout for C: expect=%s got=%s",
            C_expect.to_string().c_str(), C.to_string().c_str());
    auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "needed workspace: %zu; got: %zu", required_workspace_in_bytes,
            workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen