|
@@ -573,8 +573,15 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd( |
|
|
filter.param<dtype::QuantizedS8>().scale)); |
|
|
filter.param<dtype::QuantizedS8>().scale)); |
|
|
} else { |
|
|
} else { |
|
|
megdnn_throw(ssprintf( |
|
|
megdnn_throw(ssprintf( |
|
|
"unsupported input / filter DType: %s x %s", src.name(), |
|
|
|
|
|
filter.name())); |
|
|
|
|
|
|
|
|
"runtime does not support input / filter DType: %s x %s" |
|
|
|
|
|
"now support case list: FLOAT x FLOAT\n" |
|
|
|
|
|
" Int8 x Int8\n" |
|
|
|
|
|
" QuantizedS8 x QuantizedS8\n" |
|
|
|
|
|
" Quantized8Asymm x Quantized8Asymm\n" |
|
|
|
|
|
" QuantizedS4 x QuantizedS4\n" |
|
|
|
|
|
" Quantized4Asymm x Quantized4Asymm\n" |
|
|
|
|
|
" QuantizedS1 x QuantizedS1\n", |
|
|
|
|
|
src.name(), filter.name())); |
|
|
} |
|
|
} |
|
|
if (!dst.valid()) { |
|
|
if (!dst.valid()) { |
|
|
dst = supported_dst_dtype.at(0); |
|
|
dst = supported_dst_dtype.at(0); |
|
@@ -588,8 +595,21 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd( |
|
|
} |
|
|
} |
|
|
MEGDNN_MARK_USED_VAR(dst_supported); |
|
|
MEGDNN_MARK_USED_VAR(dst_supported); |
|
|
megdnn_assert( |
|
|
megdnn_assert( |
|
|
dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(), |
|
|
|
|
|
filter.name(), dst.name()); |
|
|
|
|
|
|
|
|
dst_supported, |
|
|
|
|
|
"runtime does not support Conv(%s, %s) -> %s" |
|
|
|
|
|
"now support case list: Conv(FLOAT x FLOAT) -> FLOAT\n" |
|
|
|
|
|
" Conv(Int8 x Int8) -> Int32\n" |
|
|
|
|
|
" Conv(QuantizedS8 x QuantizedS8) -> " |
|
|
|
|
|
"QuantizedS32\n" |
|
|
|
|
|
" Conv(Quantized8Asymm x Quantized8Asymm) -> " |
|
|
|
|
|
"Quantized32Asymm\n" |
|
|
|
|
|
" Conv(QuantizedS4 x QuantizedS4) -> " |
|
|
|
|
|
"QuantizedS32\n" |
|
|
|
|
|
" Conv(Quantized4Asymm x Quantized4Asymm) -> " |
|
|
|
|
|
"Quantized32Asymm\n" |
|
|
|
|
|
" Conv(QuantizedS1 x QuantizedS1) -> " |
|
|
|
|
|
"QuantizedS32\n", |
|
|
|
|
|
src.name(), filter.name(), dst.name()); |
|
|
} |
|
|
} |
|
|
megdnn_assert( |
|
|
megdnn_assert( |
|
|
(param().compute_mode == Param::ComputeMode::FLOAT32 || |
|
|
(param().compute_mode == Param::ComputeMode::FLOAT32 || |
|
@@ -1098,15 +1118,26 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
megdnn_throw(ssprintf( |
|
|
megdnn_throw(ssprintf( |
|
|
"unsupported input / diff DType: %s x %s", filter.name(), diff.name())); |
|
|
|
|
|
|
|
|
"runtime does not support input / diff DType: %s x %s" |
|
|
|
|
|
"now support case list: FLOAT x FLOAT\n" |
|
|
|
|
|
" Int8 x Int8\n" |
|
|
|
|
|
" QuantizedS8 x QuantizedS8\n" |
|
|
|
|
|
" Quantized8Asymm x Quantized8Asymm\n", |
|
|
|
|
|
filter.name(), diff.name())); |
|
|
} |
|
|
} |
|
|
if (!grad.valid()) { |
|
|
if (!grad.valid()) { |
|
|
grad = supported_dst_dtype.at(0); |
|
|
grad = supported_dst_dtype.at(0); |
|
|
} else { |
|
|
} else { |
|
|
megdnn_assert( |
|
|
megdnn_assert( |
|
|
vec_contains(supported_dst_dtype, grad), |
|
|
vec_contains(supported_dst_dtype, grad), |
|
|
"unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(), |
|
|
|
|
|
grad.name()); |
|
|
|
|
|
|
|
|
"runtime does not support ConvBwd(%s, %s) -> %s" |
|
|
|
|
|
"now support case list: ConvBwd(FLOAT x FLOAT) -> FLOAT\n" |
|
|
|
|
|
" ConvBwd(Int8 x Int8) -> Int32\n" |
|
|
|
|
|
" ConvBwd(QuantizedS8 x QuantizedS8) -> " |
|
|
|
|
|
"QuantizedS32\n" |
|
|
|
|
|
" ConvBwd(Quantized8Asymm x Quantized8Asymm) -> " |
|
|
|
|
|
"Quantized32Asymm\n", |
|
|
|
|
|
filter.name(), diff.name(), grad.name()); |
|
|
} |
|
|
} |
|
|
megdnn_assert( |
|
|
megdnn_assert( |
|
|
param().compute_mode != Param::ComputeMode::FLOAT32 |
|
|
param().compute_mode != Param::ComputeMode::FLOAT32 |
|
|