@@ -22,7 +22,13 @@ void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
     }
     megdnn_assert(
             C.valid() && (C == C_candi || C == C_candi2),
-            "unsupported BatchedMatMul(%s, %s) -> %s", A.name(), B.name(), C.name());
+            "runtime does not support BatchedMatMul(%s, %s) -> %s\n"
+            "now support case list: BatchedMatMul(FLOAT, FLOAT)\n"
+            "                       BatchedMatMul(Int8, Int8)\n"
+            "                       BatchedMatMul(QuantizedS8, QuantizedS8)\n"
+            "                       BatchedMatMul(Quantized8Asymm, Quantized8Asymm)\n"
+            "                       BatchedMatMul(Quantized4Asymm, Quantized4Asymm)\n",
+            A.name(), B.name(), C.name());
 }
 void BatchedMatrixMulForward::deduce_layout(
         const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
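For reference, the assertion above only accepts an output dtype C that matches one of the candidates (C_candi, C_candi2) deduced from the operand dtypes, and the new message lists the operand pairs the runtime supports. Below is a minimal, self-contained sketch of that same pair check; it is not the MegDNN implementation, and the helper name and the plain string dtype names are placeholders for illustration only.

// Hypothetical helper (not part of MegDNN): returns true when the operand
// dtype pair appears in the "now support case list" printed by the new
// assertion message. Dtype names are plain strings purely for this sketch;
// the real check compares DType objects against C_candi / C_candi2.
#include <cstdio>
#include <string>

static bool batched_matmul_pair_supported(const std::string& a, const std::string& b) {
    static const char* const kSupported[] = {
            "FLOAT", "Int8", "QuantizedS8", "Quantized8Asymm", "Quantized4Asymm"};
    if (a != b)
        return false;  // every listed case uses identical dtypes for A and B
    for (const char* name : kSupported) {
        if (a == name)
            return true;
    }
    return false;
}

int main() {
    std::printf("%d\n", batched_matmul_pair_supported("QuantizedS8", "QuantizedS8"));  // 1
    std::printf("%d\n", batched_matmul_pair_supported("FLOAT", "Int8"));               // 0
    return 0;
}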