diff --git a/dnn/src/common/images2neibs.cpp b/dnn/src/common/images2neibs.cpp index 66cd3e3a..8278f1af 100644 --- a/dnn/src/common/images2neibs.cpp +++ b/dnn/src/common/images2neibs.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megdnn/oprs.h" @@ -28,22 +29,40 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_contiguous(src); - megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); - size_t n = src[0], ic = src[1], ih = src[2], iw = src[3]; - size_t ph = this->param().pad_h; - size_t pw = this->param().pad_w; - size_t sh = this->param().stride_h; - size_t sw = this->param().stride_w; - size_t dh = this->param().dilate_h; - size_t dw = this->param().dilate_w; - size_t wh = this->param().window_h; - size_t ww = this->param().window_w; - size_t oh, ow; + megdnn_assert(src.ndim == 4_z || src.ndim == 5_z, "%s", errmsg().c_str()); - infer_conv_shape2d( - ih, iw, wh + (wh - 1) * (dh - 1), ww + (ww - 1) * (dw - 1), sh, sw, ph, pw, - oh, ow); - dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype); + if (src.ndim == 4_z) { + size_t n = src[0], ic = src[1], ih = src[2], iw = src[3]; + size_t ph = this->param().pad_h; + size_t pw = this->param().pad_w; + size_t sh = this->param().stride_h; + size_t sw = this->param().stride_w; + size_t dh = this->param().dilate_h; + size_t dw = this->param().dilate_w; + size_t wh = this->param().window_h; + size_t ww = this->param().window_w; + size_t oh, ow; + infer_conv_shape2d( + ih, iw, wh + (wh - 1) * (dh - 1), ww + (ww - 1) * (dw - 1), sh, sw, ph, + pw, oh, ow); + dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype, src.format); + } else if (src.ndim == 5_z) { + size_t n = src[0], ih = src[1], iw = src[3], ic = src[2]; + size_t ph = this->param().pad_h; + size_t pw = this->param().pad_w; + size_t sh = this->param().stride_h; + size_t sw = this->param().stride_w; + size_t dh = this->param().dilate_h; + size_t dw = this->param().dilate_w; + size_t wh = this->param().window_h; + size_t ww = this->param().window_w; + size_t oh, ow; + infer_conv_shape2d( + ih, iw, wh + (wh - 1) * (dh - 1), ww + (ww - 1) * (dw - 1), sh, sw, ph, + pw, oh, ow); + dst = TensorLayout( + TensorShape({n, oh, ic, ow, wh, ww, 4}), src.dtype, src.format); + } } void Images2NeibsBase::check_layout_fwd( diff --git a/dnn/src/naive/images2neibs/opr_impl.cpp b/dnn/src/naive/images2neibs/opr_impl.cpp index 99fc3519..c56a1723 100644 --- a/dnn/src/naive/images2neibs/opr_impl.cpp +++ b/dnn/src/naive/images2neibs/opr_impl.cpp @@ -21,40 +21,100 @@ namespace naive { template void Images2NeibsForwardImpl::exec_internal( _megdnn_tensor_in src, _megdnn_tensor_out dst) { - int N = src.layout.shape[0], C = src.layout.shape[1], IH = src.layout.shape[2], - IW = src.layout.shape[3]; - auto sptr = src.ptr(); - auto dptr = dst.ptr(); - size_t idx = 0; - int window_h = static_cast(param().window_h); - int window_w = static_cast(param().window_w); - int pad_h = static_cast(param().pad_h); - int pad_w = static_cast(param().pad_w); - int stride_h = static_cast(param().stride_h); - int stride_w = static_cast(param().stride_w); - int dilate_h = static_cast(param().dilate_h); - int dilate_w = static_cast(param().dilate_w); - int equ_window_h = dilate_h * (window_h - 1) + 1; - int equ_window_w = dilate_w * (window_w - 1) + 1; - for (int n = 0; n < N; ++n) - for (int c = 0; c < C; ++c) { - int ih = -pad_h; - for (; ih + equ_window_h <= IH + pad_h; ih += stride_h) { - int iw = -pad_w; - for (; iw + equ_window_w <= IW + pad_w; iw += stride_w) { - for (int kh = 0; kh < window_h; ++kh) - for (int kw = 0; kw < window_w; ++kw) { - int ih2 = ih + dilate_h * kh, iw2 = iw + dilate_w * kw; - dptr[idx * window_h * window_w + kh * window_w + kw] = - ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW - ? sptr[n * C * IH * IW + c * IH * IW + - ih2 * IW + iw2] - : 0.0f; - } - ++idx; + megdnn_assert(src.layout.ndim == 5 || src.layout.ndim == 4); + if (src.layout.ndim == 5) { + int N = src.layout.shape[0], C = src.layout.shape[2], IH = src.layout.shape[1], + IW = src.layout.shape[3]; + auto sptr = src.ptr(); + auto dptr = dst.ptr(); + size_t idx = 0; + int window_h = static_cast(param().window_h); + int window_w = static_cast(param().window_w); + int pad_h = static_cast(param().pad_h); + int pad_w = static_cast(param().pad_w); + int stride_h = static_cast(param().stride_h); + int stride_w = static_cast(param().stride_w); + int dilate_h = static_cast(param().dilate_h); + int dilate_w = static_cast(param().dilate_w); + int equ_window_h = dilate_h * (window_h - 1) + 1; + int equ_window_w = dilate_w * (window_w - 1) + 1; + + auto src_stride = src.layout.stride; + auto dst_stride = dst.layout.stride; + + for (int n = 0; n < N; ++n) + for (int c = 0; c < C; ++c) { + int ih = -pad_h; + int hc = 0; + for (; ih <= IH + pad_h - equ_window_h; ih += stride_h, hc++) { + int iw = -pad_w; + int wc = 0; + for (; iw <= IW + pad_w - equ_window_w; iw += stride_w, wc++) { + for (int kh = 0; kh < window_h; ++kh) + for (int kw = 0; kw < window_w; ++kw) { + for (int cn = 0; cn < 4; cn++) { + int ih2 = ih + dilate_h * kh, + iw2 = iw + dilate_w * kw; + int dst_pos = + n * dst_stride[0] + hc * dst_stride[1] + + c * dst_stride[2] + wc * dst_stride[3] + + kh * dst_stride[4] + kw * dst_stride[5] + + cn * dst_stride[6]; + int src_pos = + n * src_stride[0] + ih2 * src_stride[1] + + c * src_stride[2] + iw2 * src_stride[3] + + cn * src_stride[4]; + if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) { + dptr[dst_pos] = sptr[src_pos]; + } else { + dptr[dst_pos] = 0.0f; + } + } + } + ++idx; + } } } - } + } else { + int N = src.layout.shape[0], C = src.layout.shape[1], IH = src.layout.shape[2], + IW = src.layout.shape[3]; + auto sptr = src.ptr(); + auto dptr = dst.ptr(); + size_t idx = 0; + int window_h = static_cast(param().window_h); + int window_w = static_cast(param().window_w); + int pad_h = static_cast(param().pad_h); + int pad_w = static_cast(param().pad_w); + int stride_h = static_cast(param().stride_h); + int stride_w = static_cast(param().stride_w); + int dilate_h = static_cast(param().dilate_h); + int dilate_w = static_cast(param().dilate_w); + int equ_window_h = dilate_h * (window_h - 1) + 1; + int equ_window_w = dilate_w * (window_w - 1) + 1; + for (int n = 0; n < N; ++n) + for (int c = 0; c < C; ++c) { + int ih = -pad_h; + for (; ih + equ_window_h <= IH + pad_h; ih += stride_h) { + int iw = -pad_w; + for (; iw + equ_window_w <= IW + pad_w; iw += stride_w) { + for (int kh = 0; kh < window_h; ++kh) + for (int kw = 0; kw < window_w; ++kw) { + int ih2 = ih + dilate_h * kh, iw2 = iw + dilate_w * kw; + int src_pos = + n * C * IH * IW + c * IH * IW + ih2 * IW + iw2; + int dst_pos = + idx * window_h * window_w + kh * window_w + kw; + if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) { + dptr[dst_pos] = sptr[src_pos]; + } else { + dptr[dst_pos] = 0.0f; + } + } + ++idx; + } + } + } + } } void Images2NeibsForwardImpl::exec( diff --git a/dnn/test/common/images2neibs.h b/dnn/test/common/images2neibs.h index 257509cc..1963d7c5 100644 --- a/dnn/test/common/images2neibs.h +++ b/dnn/test/common/images2neibs.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include @@ -26,6 +27,32 @@ struct TestArg { inline std::vector get_args() { std::vector args; + + // clang-format off + for (uint32_t ph : {0, 1}) + for (uint32_t pw : {0, 1}) + for (uint32_t sh : {1, 2}) + for (uint32_t sw : {1, 2}) + for (uint32_t dh : {1, 2, 3}) + for (uint32_t dw : {1, 2, 3}) + for (uint32_t wh : {3, 4}) + for (uint32_t ww : {3, 4}) { + args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, ww}, + TensorShape{2, 3, 19, 20}); + } + // clang-format on + // large window case + args.emplace_back( + param::Images2Neibs{0, 0, 1, 1, 1, 1, 32, 64}, TensorShape{2, 3, 96, 128}); + // large size + args.emplace_back( + param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1}, TensorShape{128, 128, 28, 24}); + + return args; +} + +inline std::vector get_cd4_args() { + std::vector args; // clang-format off for (uint32_t ph : {0, 1}) for (uint32_t pw : {0, 1}) @@ -33,18 +60,21 @@ inline std::vector get_args() { for (uint32_t sw : {1, 2}) for (uint32_t dh : {1, 2, 3}) for (uint32_t dw : {1, 2, 3}) - for (uint32_t wh : {3, 4}) - for (uint32_t ww : {3, 4}) { - args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, ww}, - TensorShape{2, 3, 19, 20}); + for (uint32_t wh : {2, 3}) + for (uint32_t ww : {2, 3}) { + args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, + ww}, + TensorShape{2, 13, 1, 14, 4}); } + // clang-format on // large window case args.emplace_back( - param::Images2Neibs{0, 0, 1, 1, 1, 1, 32, 64}, TensorShape{2, 3, 96, 128}); + param::Images2Neibs{0, 0, 1, 1, 1, 1, 8, 14}, TensorShape{2, 16, 1, 16, 4}); // large size args.emplace_back( - param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1}, TensorShape{128, 128, 28, 24}); + param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1}, + TensorShape{256, 16, 64, 16, 4}); return args; } @@ -75,6 +105,33 @@ inline std::vector get_benchmark_args() { return args; } +inline std::vector get_benchmark_args_cd4() { + std::vector args; + // clang-format off + for (uint32_t ph : {0, 1}) + for (uint32_t pw : {0, 1}) + for (uint32_t sh : {1, 2}) + for (uint32_t sw : {1, 2}) + for (uint32_t dh : {1, 2}) + for (uint32_t dw : {1, 2}) + for (uint32_t wh : {3, 4}) + for (uint32_t ww : {3, 4}) + for (uint32_t b : {1, 32}) + for (uint32_t c : {16, 32}) + for (uint32_t hw : {16, 32}) { + args.emplace_back(param::Images2Neibs{ph, pw, sh, sw, dh, dw, wh, ww}, + TensorShape{b, hw, (c + 3) / 4, hw, 4}); + } + + // clang-format on + // large size + args.emplace_back( + param::Images2Neibs{0, 0, 1, 1, 1, 1, 1, 1}, + TensorShape{256, 28, 32, 24, 4}); + + return args; +} + } // namespace images2neibs } // namespace test } // namespace megdnn diff --git a/dnn/test/naive/images2neibs.cpp b/dnn/test/naive/images2neibs.cpp index 4fae6875..83cc0123 100644 --- a/dnn/test/naive/images2neibs.cpp +++ b/dnn/test/naive/images2neibs.cpp @@ -56,3 +56,68 @@ TEST_F(NAIVE, IMAGES2NEIBS_FORWARD) { 8, 10, 0, 22, 24, 0, 36, 38, 8, 10, 12, 22, 24, 26, 36, 38, 40, 10, 12, 0, 24, 26, 0, 38, 40, 0})}); } + +TEST_F(NAIVE, IMAGES2NEIBS_FORWARD_CD4) { + Checker checker(handle(), /* check_dispatch */ false); + + Images2Neibs::Param param(0, 0, 1, 1, 1, 1, 2, 2); + + checker.set_param(param).exect( + Testcase{ + TensorValue( + {1, 3, 1, 3, 4}, dtype::Uint8(), + {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, + 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 8, 0, 0, 0}), + {}}, + Testcase{ + {}, + TensorValue( + {1, 2, 1, 2, 2, 2, 4}, dtype::Uint8(), + {0, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, + 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, + 3, 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, + 4, 0, 0, 0, 5, 0, 0, 0, 7, 0, 0, 0, 8, 0, 0, 0})}); + + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.dilate_h = 2; + param.dilate_w = 2; + param.window_h = 3; + param.window_w = 3; + checker.set_param(param).exect( + Testcase{ + TensorValue( + {1, 6, 1, 7, 4}, dtype::Uint8(), + {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, + 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, + 8, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, + 12, 0, 0, 0, 13, 0, 0, 0, 14, 0, 0, 0, 15, 0, 0, 0, + 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 19, 0, 0, 0, + 20, 0, 0, 0, 21, 0, 0, 0, 22, 0, 0, 0, 23, 0, 0, 0, + 24, 0, 0, 0, 25, 0, 0, 0, 26, 0, 0, 0, 27, 0, 0, 0, + 28, 0, 0, 0, 29, 0, 0, 0, 30, 0, 0, 0, 31, 0, 0, 0, + 32, 0, 0, 0, 33, 0, 0, 0, 34, 0, 0, 0, 35, 0, 0, 0, + 36, 0, 0, 0, 37, 0, 0, 0, 38, 0, 0, 0, 39, 0, 0, 0, + 40, 0, 0, 0, 41, 0, 0, 0}), + {}}, + Testcase{ + {}, + TensorValue( + {1, 2, 1, 3, 3, 3, 4}, dtype::Uint8(), + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, + 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 22, 0, 0, 0, + 24, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, + 24, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, + 24, 0, 0, 0, 0, 0, 0, 0, 36, 0, 0, 0, 38, 0, 0, 0, + 8, 0, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 22, 0, 0, 0, + 24, 0, 0, 0, 26, 0, 0, 0, 36, 0, 0, 0, 38, 0, 0, 0, + 40, 0, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, + 24, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 38, 0, 0, 0, + 40, 0, 0, 0, 0, 0, 0, 0})}); +} diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 96f1e830..42a1a14d 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -17,6 +17,7 @@ #include "megbrain/opr/blas.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" @@ -1651,6 +1652,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::Concat::typeinfo()] = replace_concat_opr; replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw; replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw; + replace_func[opr::Images2NeibsBackward::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw; replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;