#include "megdnn/oprs.h" #include "src/common/utils.h" namespace megdnn { void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { DType C_candi, C_candi2; if (A.category() == DTypeCategory::FLOAT) { C_candi = A; } else if (A.enumv() == DTypeEnum::Int8) { C_candi = dtype::Int32(); C_candi2 = dtype::Int16(); } else if (A.enumv() == DTypeEnum::QuantizedS8) { C_candi = dtype::QuantizedS32(mul_scale(A, B)); } else if (A.enumv() == DTypeEnum::Quantized8Asymm) { C_candi = dtype::QuantizedS32(mul_scale(A, B)); } else if (A.enumv() == DTypeEnum::Quantized4Asymm) { C_candi = dtype::QuantizedS32(mul_scale(A, B)); } if (!C.valid()) { C = C_candi; } megdnn_assert( C.valid() && (C == C_candi || C == C_candi2), "runtime does not support BatchedMatMul(%s, %s) -> %s\n" "now support case list: BatchedMatMul(FLOAT, FLOAT)\n" " BatchedMatMul(Int8, Int8)\n" " BatchedMatMul(QuantizedS8, QuantizedS8)\n" " BatchedMatMul(Quantized8Asymm, Quantized8Asymm)\n" " BatchedMatMul(Quantized4Asymm, Quantized4Asymm)\n", A.name(), B.name(), C.name()); } void BatchedMatrixMulForward::deduce_layout( const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { auto errmsg = [&]() { std::string msg; msg.append("A="); msg.append(A.to_string()); msg.append(", B="); msg.append(B.to_string()); msg.append(", C="); msg.append(C.to_string()); msg.append(", transposeA="); msg.append(std::to_string(m_param.transposeA)); msg.append(", transposeB="); msg.append(std::to_string(m_param.transposeB)); return msg; }; MEGDNN_MARK_USED_VAR(errmsg); auto good_layout = [](const TensorLayout& l) { // l.stride[0] == 0 because im2col conv need batched matrixmul and // filter tensor need to be broadcasted. It's only implemented in // opencl. return l.ndim == 3 && l.stride[2] == 1 && l.stride[1] >= static_cast(l.shape[2]) && (l.shape[0] == 1 || l.stride[0] >= static_cast(l.shape[1]) * l.stride[1] || l.stride[0] == 0); }; size_t A0, A1, B0, B1; A0 = A.shape[1]; A1 = A.shape[2]; B0 = B.shape[1]; B1 = B.shape[2]; if (m_param.transposeA) std::swap(A0, A1); if (m_param.transposeB) std::swap(B0, B1); deduce_dtype(A.dtype, B.dtype, C.dtype); megdnn_assert( good_layout(A) && good_layout(B) && A1 == B0 && A[0] == B[0] && A.dtype.enumv() == B.dtype.enumv(), "bad input layouts: %s", errmsg().c_str()); C = TensorLayout(TensorShape({A[0], A0, B1}), C.dtype); } void BatchedMatrixMulForward::check_exec( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_in_bytes) { TensorLayout C_expect; deduce_layout(A, B, C_expect); megdnn_assert( C_expect.eq_layout(C), "bad layout for C: expect=%s got=%s", C_expect.to_string().c_str(), C.to_string().c_str()); auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); megdnn_assert( workspace_in_bytes >= required_workspace_in_bytes, "needed workspace: %zu; got: %zu", required_workspace_in_bytes, workspace_in_bytes); } } // namespace megdnn // vim: syntax=cpp.doxygen