GitOrigin-RevId: 96b922dd20
HuaHua404-patch-4
@@ -22,7 +22,13 @@ void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { | |||||
} | } | ||||
megdnn_assert( | megdnn_assert( | ||||
C.valid() && (C == C_candi || C == C_candi2), | C.valid() && (C == C_candi || C == C_candi2), | ||||
"unsupported BatchedMatMul(%s, %s) -> %s", A.name(), B.name(), C.name()); | |||||
"runtime does not support BatchedMatMul(%s, %s) -> %s\n" | |||||
"now support case list: BatchedMatMul(FLOAT, FLOAT)\n" | |||||
" BatchedMatMul(Int8, Int8)\n" | |||||
" BatchedMatMul(QuantizedS8, QuantizedS8)\n" | |||||
" BatchedMatMul(Quantized8Asymm, Quantized8Asymm)\n" | |||||
" BatchedMatMul(Quantized4Asymm, Quantized4Asymm)\n", | |||||
A.name(), B.name(), C.name()); | |||||
} | } | ||||
void BatchedMatrixMulForward::deduce_layout( | void BatchedMatrixMulForward::deduce_layout( | ||||
const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { | const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { | ||||
@@ -31,7 +31,15 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { | |||||
} | } | ||||
megdnn_assert( | megdnn_assert( | ||||
C.valid() && (C == C_candi || C == C_candi2), | C.valid() && (C == C_candi || C == C_candi2), | ||||
"unsupported MatMul(%s, %s) -> %s", A.name(), B.name(), C.name()); | |||||
"runtime does not support MatMul(%s, %s) -> %s\n" | |||||
"now support case list: MatMul(FLOAT, FLOAT)\n" | |||||
" MatMul(Int8, Int8)\n" | |||||
" MatMul(Int16, Int16)\n" | |||||
" MatMul(QuantizedS8, QuantizedS8)\n" | |||||
" MatMul(Quantized8Asymm, Quantized8Asymm)\n" | |||||
" MatMul(Quantized4Asymm, Quantized4Asymm)\n" | |||||
" MatMul(QuantizedS4, QuantizedS4)\n", | |||||
A.name(), B.name(), C.name()); | |||||
} | } | ||||
void MatrixMulForward::deduce_layout( | void MatrixMulForward::deduce_layout( | ||||
@@ -65,7 +65,14 @@ void MatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const { | |||||
#undef DISPATCH_CMODE | #undef DISPATCH_CMODE | ||||
#undef DISPATCH | #undef DISPATCH | ||||
megdnn_throw(ssprintf( | megdnn_throw(ssprintf( | ||||
"unsupported Matmul(%s, %s) -> %s with cmode = %d", | |||||
"runtime does not support MatMul(%s, %s) -> %s with cmode = %d\n" | |||||
"now support case list: MatMul(FLOAT, FLOAT)\n" | |||||
" MatMul(Int8, Int8)\n" | |||||
" MatMul(Int16, Int16)\n" | |||||
" MatMul(QuantizedS8, QuantizedS8)\n" | |||||
" MatMul(Quantized8Asymm, Quantized8Asymm)\n" | |||||
" MatMul(Quantized4Asymm, Quantized4Asymm)\n" | |||||
" MatMul(QuantizedS4, QuantizedS4)\n", | |||||
args.layout_a.dtype.name(), args.layout_b.dtype.name(), | args.layout_a.dtype.name(), args.layout_b.dtype.name(), | ||||
args.layout_c.dtype.name(), static_cast<int>(param.compute_mode))); | args.layout_c.dtype.name(), static_cast<int>(param.compute_mode))); | ||||
} | } | ||||