Browse Source

fix(mgb): make error information of input channel mismatch more readable

GitOrigin-RevId: 6f95260070
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
f5597d9a10
2 changed files with 53 additions and 13 deletions
  1. +38
    -7
      dnn/src/common/convolution.cpp
  2. +15
    -6
      imperative/src/impl/ops/convolution.cpp

+ 38
- 7
dnn/src/common/convolution.cpp View File

@@ -573,8 +573,15 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
filter.param<dtype::QuantizedS8>().scale)); filter.param<dtype::QuantizedS8>().scale));
} else { } else {
megdnn_throw(ssprintf( megdnn_throw(ssprintf(
"unsupported input / filter DType: %s x %s", src.name(),
filter.name()));
"runtime does not support input / filter DType: %s x %s"
"now support case list: FLOAT x FLOAT\n"
" Int8 x Int8\n"
" QuantizedS8 x QuantizedS8\n"
" Quantized8Asymm x Quantized8Asymm\n"
" QuantizedS4 x QuantizedS4\n"
" Quantized4Asymm x Quantized4Asymm\n"
" QuantizedS1 x QuantizedS1\n",
src.name(), filter.name()));
} }
if (!dst.valid()) { if (!dst.valid()) {
dst = supported_dst_dtype.at(0); dst = supported_dst_dtype.at(0);
@@ -588,8 +595,21 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
} }
MEGDNN_MARK_USED_VAR(dst_supported); MEGDNN_MARK_USED_VAR(dst_supported);
megdnn_assert( megdnn_assert(
dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
filter.name(), dst.name());
dst_supported,
"runtime does not support Conv(%s, %s) -> %s"
"now support case list: Conv(FLOAT x FLOAT) -> FLOAT\n"
" Conv(Int8 x Int8) -> Int32\n"
" Conv(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32\n"
" Conv(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm\n"
" Conv(QuantizedS4 x QuantizedS4) -> "
"QuantizedS32\n"
" Conv(Quantized4Asymm x Quantized4Asymm) -> "
"Quantized32Asymm\n"
" Conv(QuantizedS1 x QuantizedS1) -> "
"QuantizedS32\n",
src.name(), filter.name(), dst.name());
} }
megdnn_assert( megdnn_assert(
(param().compute_mode == Param::ComputeMode::FLOAT32 || (param().compute_mode == Param::ComputeMode::FLOAT32 ||
@@ -1098,15 +1118,26 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad
} }
} else { } else {
megdnn_throw(ssprintf( megdnn_throw(ssprintf(
"unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
"runtime does not support input / diff DType: %s x %s"
"now support case list: FLOAT x FLOAT\n"
" Int8 x Int8\n"
" QuantizedS8 x QuantizedS8\n"
" Quantized8Asymm x Quantized8Asymm\n",
filter.name(), diff.name()));
} }
if (!grad.valid()) { if (!grad.valid()) {
grad = supported_dst_dtype.at(0); grad = supported_dst_dtype.at(0);
} else { } else {
megdnn_assert( megdnn_assert(
vec_contains(supported_dst_dtype, grad), vec_contains(supported_dst_dtype, grad),
"unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
grad.name());
"runtime does not support ConvBwd(%s, %s) -> %s"
"now support case list: ConvBwd(FLOAT x FLOAT) -> FLOAT\n"
" ConvBwd(Int8 x Int8) -> Int32\n"
" ConvBwd(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32\n"
" ConvBwd(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm\n",
filter.name(), diff.name(), grad.name());
} }
megdnn_assert( megdnn_assert(
param().compute_mode != Param::ComputeMode::FLOAT32 param().compute_mode != Param::ComputeMode::FLOAT32


+ 15
- 6
imperative/src/impl/ops/convolution.cpp View File

@@ -95,8 +95,11 @@ TensorLayout do_shape_infer(
dilated_spatial[i] = dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1; (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
} }
mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");

mgb_assert(
icpg * group == src[src_or_dst_c_pos],
"group conv invalid: input channel of Conv expect %zu, but got %zu\n"
"hint: weight may be changed by mistake\n",
icpg * group, src[src_or_dst_c_pos]);
TensorLayout dst{src.dtype}; TensorLayout dst{src.dtype};
dst.ndim = src_ndim; dst.ndim = src_ndim;
dst[0] = src[0]; dst[0] = src[0];
@@ -310,8 +313,11 @@ TensorLayout convbwd_do_shape_infer(
dilated_spatial[i] = dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1; (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
} }
mgb_assert(ocpg * group == diff[src_or_dst_c_pos], "group conv invalid");

mgb_assert(
ocpg * group == diff[src_or_dst_c_pos],
"group conv invalid: input channel of Conv expect %zu, but got %zu\n"
"hint: weight may be changed by mistake\n",
ocpg * group, diff[src_or_dst_c_pos]);
auto deduce = [](size_t out, size_t filter, size_t stride, size_t pad) { auto deduce = [](size_t out, size_t filter, size_t stride, size_t pad) {
auto i = (out - 1) * stride + filter; auto i = (out - 1) * stride + filter;
mgb_assert(i > pad * 2); mgb_assert(i > pad * 2);
@@ -479,8 +485,11 @@ TensorLayout do_shape_infer(
dilated_spatial[i] = dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1; (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
} }
mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");

mgb_assert(
icpg * group == src[src_or_dst_c_pos],
"group conv invalid: input channel of Conv expect %zu, but got %zu\n"
"hint: weight may be changed by mistake\n",
icpg * group, src[src_or_dst_c_pos]);
TensorLayout dst{src.dtype}; TensorLayout dst{src.dtype};
dst.ndim = src_ndim; dst.ndim = src_ndim;
dst[0] = src[0]; dst[0] = src[0];


Loading…
Cancel
Save