Browse Source

fix(mgb): make error infomation of group conv input channel mismatch more readable

GitOrigin-RevId: d249408c26
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
f0291883b6
2 changed files with 58 additions and 25 deletions
  1. +54
    -24
      dnn/src/common/convolution.cpp
  2. +4
    -1
      dnn/src/common/convolution3d.cpp

+ 54
- 24
dnn/src/common/convolution.cpp View File

@@ -777,8 +777,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src_or_dst_spatial_start = 1;
}
megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
errmsg().c_str());
cflt.icpg * cflt.group == src[src_or_dst_c_pos],
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
@@ -792,8 +794,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -809,8 +813,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 8,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -826,8 +832,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 32,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -856,7 +864,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8 ||
(cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
}

} else if (
@@ -879,15 +891,21 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4 ||
(cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
}
} else if (param().format == Param::Format::CHWN4) {
megdnn_assert(
src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[0] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[3] = src[3];
auto oc = cflt.ocpg * cflt.group;
@@ -903,8 +921,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -918,8 +938,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4;
dst[0] = src[0];
dst[1] = infer_conv_shape(
@@ -933,8 +955,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -950,8 +974,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 32,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -967,8 +993,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 64,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -985,8 +1013,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[2] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;


+ 4
- 1
dnn/src/common/convolution3d.cpp View File

@@ -148,7 +148,10 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
src_or_dst_spatial_start = 1;
}
megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", errmsg().c_str());
cflt.icpg * cflt.group == src[src_or_dst_c_pos],
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;


Loading…
Cancel
Save