|
|
@@ -777,8 +777,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src_or_dst_spatial_start = 1; |
|
|
|
} |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", |
|
|
|
errmsg().c_str()); |
|
|
|
cflt.icpg * cflt.group == src[src_or_dst_c_pos], |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group; |
|
|
@@ -792,8 +794,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -809,8 +813,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 8, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -826,8 +832,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 32, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -856,7 +864,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 8 || |
|
|
|
(cflt.icpg * cflt.group == src[1]), |
|
|
|
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details about src, filter and dst : " |
|
|
|
"\n%s", |
|
|
|
src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group, |
|
|
|
errmsg().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
} else if ( |
|
|
@@ -879,15 +891,21 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 4 || |
|
|
|
(cflt.icpg * cflt.group == src[1]), |
|
|
|
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details about src, filter and dst : " |
|
|
|
"\n%s", |
|
|
|
src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group, |
|
|
|
errmsg().c_str()); |
|
|
|
} |
|
|
|
} else if (param().format == Param::Format::CHWN4) { |
|
|
|
megdnn_assert( |
|
|
|
src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[0] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[3] = src[3]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -903,8 +921,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = 4; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -918,8 +938,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = 4; |
|
|
|
dst[0] = src[0]; |
|
|
|
dst[1] = infer_conv_shape( |
|
|
@@ -933,8 +955,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -950,8 +974,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 32, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -967,8 +993,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[1] * 64, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
@@ -985,8 +1013,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet |
|
|
|
src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu", |
|
|
|
src.ndim); |
|
|
|
megdnn_assert( |
|
|
|
cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u", |
|
|
|
errmsg().c_str(), cflt.icpg, cflt.group); |
|
|
|
cflt.icpg * cflt.group == src[2] * 4, |
|
|
|
"group conv channel mismatch : input channel got %zu, and " |
|
|
|
"filter channel got %u. More details for src, filter and dst : \n%s", |
|
|
|
src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str()); |
|
|
|
dst.ndim = src.ndim; |
|
|
|
dst[0] = src[0]; |
|
|
|
auto oc = cflt.ocpg * cflt.group; |
|
|
|