|
@@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); |
|
|
auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
|
|
|
|
|
|
|
|
|
// skip dtype promotion when inputs are quantized |
|
|
|
|
|
if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) { |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
dtypes[0].category() == dtypes[1].category(), |
|
|
|
|
|
"inputs of matmul should have same quantized dtype."); |
|
|
|
|
|
return imperative::apply(op, inputs); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
mgb::DType target_dtype; |
|
|
mgb::DType target_dtype; |
|
|
|
|
|
|
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
@@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
auto&& conv_op = |
|
|
auto&& conv_op = |
|
|
const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); |
|
|
const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
|
|
|
|
|
|
|
|
|
// skip dtype promotion when inputs are quantized |
|
|
|
|
|
if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) { |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
dtypes[0].category() == dtypes[1].category(), |
|
|
|
|
|
"inputs of batched matmul should have same quantized dtype."); |
|
|
|
|
|
return imperative::apply(op, inputs); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
mgb::DType target_dtype; |
|
|
mgb::DType target_dtype; |
|
|
|
|
|
|
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { |
|
|