diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index d688166c..a12f2a53 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -41,6 +41,7 @@ #endif #define rep(i, n) for (auto i = decltype(n){0}; i < (n); ++i) +#define rep_step(i, n, step) for (auto i = decltype(n){0}; i < (n); i += (step)) #define megdnn_assert_contiguous(layout) \ do { \ diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index c4e5a7ac..4619941d 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -375,6 +375,25 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, break; } + case Param::Format::NCHW44: { + size_t group = filter_meta.group; + size_t icpg = filter_meta.icpg; + size_t ocpg = filter_meta.ocpg; + //! four format of weight layout + //! 1. {oc/4, ic/4, fh, fw, 4, 4}, + //! 2. {g, oc/4, ic/4, fh, fw, 4, 4}, + //! 3. {g/4, fh, fw, 1, 1, 4}, 4. {oc/4, fh, fw, ic, 4} + megdnn_assert((icpg % 4 == 0 && ocpg % 4 == 0) || + (group % 4 == 0 && icpg == 1 && ocpg == 1 && + pack_group_size > 1) || + (group == 1 && ocpg % 4 == 0), + "The filter shepe is not right of nchw44"); + group_offset = pack_group_size * group_pack_id * filter_meta.icpg * + filter_meta.ocpg * filter_meta.spatial[0] * + filter_meta.spatial[1] * filter_type.size(); + + break; + } case ConvBiasImpl::Param::Format::NCHW_WINOGRAD: case ConvBiasImpl::Param::Format::NCHW88_WINOGRAD: { //! four format of weight layout