diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index 41f88ff1..37975b0c 100644 --- a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -177,6 +177,15 @@ UNROLL_RAW_5x2(cb, v0, ##a) \ cb(5, 0, ##a) cb(5, 1, ##a) +#define UNROLL_RAW_4x6(cb, v0, a...) \ + cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) cb(0, 4, ##a) cb(0, 5, ##a) \ + cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) cb(1, 4, ##a) cb(1, 5, ##a) \ + cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a) cb(2, 4, ##a) cb(2, 5, ##a) \ + cb(3, 0, ##a) cb(3, 1, ##a) cb(3, 2, ##a) cb(3, 3, ##a) cb(3, 4, ##a) cb(3, 5, ##a) +#define UNROLL_RAW_5x6(cb, v0, a...) \ + UNROLL_RAW_4x6(cb, v0, ##a) \ + cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) cb(4, 4, ##a) cb(4, 5, ##a) + #define UNROLL_CALL0_D2(step, step2, cb, v...) \ UNROLL_RAW_##step##x##step2(cb, 0, ##v) #define UNROLL_CALL1_D2(step, step2, cb, v...) \ diff --git a/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp index 5ae18365..64e7897c 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp @@ -218,6 +218,44 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); +/* ======================= AlgoFP32WinogradF43_4x4 ======================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF43_4x4::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 6, 0) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_4x3_4x4_f; + using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == PackMode::NO_PACK && + param.filter_meta.format == param::ConvBias::Format::NCHW && + !param.filter_meta.should_flip && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_meta.icpg % 4 == 0 && param.filter_meta.ocpg % 4 == 0; + } + MIDOUT_END(); + return false; +} + +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF43_4x4, winograd::winograd_4x3_4x4_f, + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); + /* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( @@ -297,6 +335,46 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); +/* =================== AlgoFP32WinogradF43_4x4_NCHW44 ===================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF43_4x4_NCHW44::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN( + megdnn_fallback_winograd_fp32, + midout_iv("AlgoFP32WinogradF43_4x4_NCHW44"_hash)) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_F43_mk4_f_nchw44; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == + fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK && + param.filter_meta.format == param::ConvBias::Format::NCHW44 && + !param.filter_meta.should_flip && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_meta.icpg % 4 == 0 && param.filter_meta.ocpg % 4 == 0; + } + MIDOUT_END(); + return false; +} + +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF43_4x4_NCHW44, winograd::winograd_F43_mk4_f_nchw44, + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); + /* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( diff --git a/dnn/src/fallback/conv_bias/gi/fp32/algos.h b/dnn/src/fallback/conv_bias/gi/fp32/algos.h index cbba0fad..3b8f1aad 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/algos.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.h @@ -81,6 +81,23 @@ public: MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_FP32) }; +class ConvBiasImpl::AlgoFP32WinogradF43_4x4 final : public AlgoBase { +public: + AlgoFP32WinogradF43_4x4( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 4, m_tile_size, 3}); + } + return m_name.c_str(); + } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_4X4_FP32) +}; + class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { public: AlgoFP32WinogradF54( @@ -156,6 +173,24 @@ public: MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) }; +class ConvBiasImpl::AlgoFP32WinogradF43_4x4_NCHW44 final : public AlgoBase { +public: + AlgoFP32WinogradF43_4x4_NCHW44( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 4, m_tile_size, 3}, + param::ConvBias::Format::NCHW44); + } + return m_name.c_str(); + } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_4X4_NCHW44_F32) +}; + class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { public: AlgoFP32WinogradF73_4x4_NCHW44( diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy.h b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h index fe09b269..964239be 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h @@ -16,6 +16,8 @@ MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 3, 1, 1, winograd_4x MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 3, 4, 4, winograd_4x3_4x4_f) + MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, winograd_5x4_1x1_f) MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, winograd_4x5_1x1_f) @@ -27,6 +29,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY( float, float, float, float, 6, 3, 4, 4, winograd_F63_mk4_f_nchw44) MEGDNN_REG_WINOGRAD_STRATEGY( + float, float, float, float, 4, 3, 4, 4, winograd_F43_mk4_f_nchw44) + +MEGDNN_REG_WINOGRAD_STRATEGY( float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) } // namespace winograd } // namespace fallback diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3_4x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3_4x4.cpp new file mode 100644 index 00000000..326a62e0 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3_4x4.cpp @@ -0,0 +1,340 @@ +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F43_4x4) + +using namespace megdnn; +using namespace fallback; + +namespace { +#define MLAF GiMultiplyAddScalarFloat32 +#define MLSF GiMultiplySubScalarFloat32 +struct InputTransform4X3 { + /** + * @brief Convert layout from NCHW to NCHW44(i.e. NC4HW4) + * + * @tparam inner Whether all data in [[ih_start, ih_start+6), [iw_start, + * iw_start+6)] is in @input + * @param input Pointer which points to all input data(CHW, exclude dim N) + * @param patch Buffer which size is sizeof(float) * 4 * 6 * 6. Continuous storage + * of data for the current block, order by C, H, W. + * @param patchT RETURN + * @param ih_start The start index of dim H of current block + * @param iw_start The start index of dim W of current block + * @param IH Dim H of input + * @param IW Dim W of input + * @param ic The index of dim C of input + * @param IC Dim C of input + */ + template + static void transpose( + const float* input, float* patch, float* patchT, int ih_start, int iw_start, + size_t IH, size_t IW, size_t ic, size_t IC) { + constexpr size_t alpha = 4 + 3 - 1; + if (!inner || ic + 4 > IC) { + memset(patch, 0, sizeof(float) * 4 * alpha * alpha); + } + + if (inner) { + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; + for (size_t ico = 0; ico < 4; ++ico) { + if (ic + ico < IC) { +#define cb(i) \ + auto v##i##0 = GiLoadFloat32(input_ptr + i * IW); \ + GiStoreFloat32(patch + ico * alpha * alpha + i * alpha, v##i##0); \ + auto v##i##1 = GiLoadFloat32LowHalf(input_ptr + i * IW + 4); \ + GiStoreFloat32(patch + ico * alpha * alpha + i * alpha + 4, v##i##1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + input_ptr += IH * IW; + } + } + } else { + size_t ih0 = std::max(0, ih_start), ih1 = std::min(ih_start + alpha, IH), + iw0 = std::max(0, iw_start), iw1 = std::min(iw_start + alpha, IW); + for (size_t ico = 0; ico < 4 && ic + ico < IC; ++ico) { + for (size_t ih = ih0; ih < ih1; ++ih) { + for (size_t iw = iw0; iw < iw1; ++iw) { + patch[ico * alpha * alpha + (ih - ih_start) * alpha + + (iw - iw_start)] = + input[(ic + ico) * IH * IW + ih * IW + iw]; + } + } + } + } + +#define cb(i) transpose_4x4(patch + i * 4, patchT + i * 16, 36, 4); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + } + + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { + constexpr size_t alpha = 4 + 3 - 1; +#define cb(m, n) \ + GI_FLOAT32_t d##m##n = GiLoadFloat32(patchT + m * alpha * 4 + n * 4), wd##m##n; + UNROLL_CALL_NOWRAPPER_D2(6, 6, cb); +#undef cb + //! BT + //! 4 0 -5 0 1 0 + //! 0 -4 -4 1 1 0 + //! 0 4 -4 -1 1 0 + //! 0 -2 -1 2 1 0 + //! 0 2 -1 -2 1 0 + //! 0 4 0 -5 0 1 + + //! wd0n = 4 * (d0n - d2n) + (d4n - d2n) + //! wd1n = (d3n + d4n) - 4 * (d1n + d2n) + //! wd2n = 4 * (d1n - d2n) + (d4n - d3n) + //! wd3n = (d4n - d2n) - 2 * (d1n - d3n) + //! wd4n = 2 * (d1n - d3n) + (d4n - d2n) + //! wd5n = 4 * (d1n - d3n) + (d5n - d3n) +#define cb(n) \ + { \ + auto&& d4subd2 = SUBF(d4##n, d2##n); \ + auto&& d1subd3 = SUBF(d1##n, d3##n); \ + wd0##n = MLAF(d4subd2, SUBF(d0##n, d2##n), 4.0f); \ + wd1##n = MLSF(ADDF(d3##n, d4##n), ADDF(d1##n, d2##n), 4.0f); \ + wd2##n = MLAF(SUBF(d4##n, d3##n), SUBF(d1##n, d2##n), 4.0f); \ + auto&& double_d1subd3 = MULSF(d1subd3, 2.0f); \ + wd3##n = SUBF(d4subd2, double_d1subd3); \ + wd4##n = ADDF(double_d1subd3, d4subd2); \ + wd5##n = MLAF(SUBF(d5##n, d3##n), d1subd3, 4.0f); \ + } + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + //! B + //! 4 0 0 0 0 0 + //! 0 -4 4 -2 2 4 + //! -5 -4 -4 -1 -1 0 + //! 0 1 -1 2 -2 -5 + //! 1 1 1 1 1 0 + //! 0 0 0 0 0 1 + + //! dm0 = 4 * (wdm0 - wdm2) + (wdm4 - wdm2) + //! dm1 = (wdm3 + wdm4) - 4 * (wdm1 + wdm2) + //! dm2 = 4 * (wdm1 - wdm2) + (wdm4 - wdm3) + //! dm3 = (wdm4 - wdm2) - 2 * (wdm1 - wdm3) + //! dm4 = 2 * (wdm1 - wdm3) + (wdm4 - wdm2) + //! dm5 = 4 * (wdm1 - wdm3) + (wdm5 - wdm3) +#define cb(m) \ + { \ + auto&& wd4subwd2 = SUBF(wd##m##4, wd##m##2); \ + auto&& wd1subwd3 = SUBF(wd##m##1, wd##m##3); \ + d##m##0 = MLAF(wd4subwd2, SUBF(wd##m##0, wd##m##2), 4.0f); \ + d##m##1 = MLSF(ADDF(wd##m##3, wd##m##4), ADDF(wd##m##1, wd##m##2), 4.0f); \ + d##m##2 = MLAF(SUBF(wd##m##4, wd##m##3), SUBF(wd##m##1, wd##m##2), 4.0f); \ + auto&& double_wd1subwd3 = MULSF(wd1subwd3, 2.0f); \ + d##m##3 = SUBF(wd4subwd2, double_wd1subwd3); \ + d##m##4 = ADDF(double_wd1subwd3, wd4subwd2); \ + d##m##5 = MLAF(SUBF(wd##m##5, wd##m##3), wd1subwd3, 4.0f); \ + } + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + size_t ICB = IC / 4; + size_t icb = ic / 4; +#define cb(m, n) \ + GiStoreFloat32( \ + input_transform_buf + (m * alpha + n) * ICB * 4 * nr_units_in_tile + \ + icb * nr_units_in_tile * 4 + unit_idx * 4, \ + d##m##n); + UNROLL_CALL_NOWRAPPER_D2(6, 6, cb); +#undef cb + } +}; // InputTransform4X3 + +template +struct OutputTransform4X3 { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + constexpr size_t alpha = 4 + 3 - 1; + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / 4; + size_t ocb = oc_index / 4; + +#define cb(m, n) \ + auto v##m##n = GiLoadFloat32( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ + ocb * nr_units_in_tile * 4 + unit_idx * 4); + UNROLL_CALL_NOWRAPPER_D2(6, 6, cb); +#undef cb + + //! AT + //! 1 1 1 1 1 0 + //! 0 1 -1 2 -2 0 + //! 0 1 1 4 4 0 + //! 0 1 -1 8 -8 1 + + //! t0n = v0n + (v1n + v2n) + (v3n + v4n) + //! t1n = (v1n - v2n) + 2 * (v3n - v4n) + //! t2n = (v1n + v2n) + 4 * (v3n + v4n) + //! t3n = (v1n - v2n) + 8 * (v3n - v4n) + v5n +#define cb(m, n) GI_FLOAT32_t t##m##n; + UNROLL_CALL_NOWRAPPER_D2(4, 6, cb); +#undef cb + +#define cb(n) \ + { \ + auto&& v1addv2 = ADDF(v1##n, v2##n); \ + auto&& v1subv2 = SUBF(v1##n, v2##n); \ + auto&& v3addv4 = ADDF(v3##n, v4##n); \ + auto&& v3subv4 = SUBF(v3##n, v4##n); \ + \ + t0##n = ADDF(ADDF(v0##n, v1addv2), v3addv4); \ + t1##n = MLAF(v1subv2, v3subv4, 2.0f); \ + t2##n = MLAF(v1addv2, v3addv4, 4.0f); \ + t3##n = ADDF(MLAF(v1subv2, v3subv4, 8.0f), v5##n); \ + } + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +//! A +//! 1 0 0 0 +//! 1 1 1 1 +//! 1 -1 1 -1 +//! 1 2 4 8 +//! 1 -2 4 -8 +//! 0 0 0 1 + +// vm0 = tm0 + (tm1 + tm2) + (tm3 + tm4) +// vm1 = (tm1 - tm2) + 2 * (tm3 - tm4) +// vm2 = (tm1 + tm2) + 4 * (tm3 + tm4) +// vm3 = (tm1 - tm2) + 8 * (tm3 - tm4) + tm5 +#define cb(m) \ + { \ + auto&& t1addt2 = ADDF(t##m##1, t##m##2); \ + auto&& t1subt2 = SUBF(t##m##1, t##m##2); \ + auto&& t3addt4 = ADDF(t##m##3, t##m##4); \ + auto&& t3subt4 = SUBF(t##m##3, t##m##4); \ + v##m##0 = ADDF(ADDF(t##m##0, t1addt2), t3addt4); \ + v##m##1 = MLAF(t1subt2, t3subt4, 2.0f); \ + v##m##2 = MLAF(t1addt2, t3addt4, 4.0f); \ + v##m##3 = ADDF(MLAF(t1subt2, t3subt4, 8.0f), t##m##5); \ + } + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + GI_FLOAT32_t vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = GiLoadFloat32(bias + oc); +#define cb(m, n) v##m##n = GiAddFloat32(v##m##n, vbias); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + } + + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(v##m##n); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + } + +#define cb(m, n) GiStoreFloat32(transform_mid_buf + (4 * m + n) * 4, v##m##n); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { + for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) { + float res = transform_mid_buf[oho * 4 * 4 + owo * 4 + oco]; + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + if (bmode == BiasMode::BIAS) { + res += bias[(oc + oco) * OH * OW + oh * OW + ow]; + res = op(res); + } + output[(oc + oco) * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; // OutputTransform4X3 + +#undef MLSF +#undef MLAF +} // namespace + +namespace megdnn { +namespace fallback { +namespace winograd { +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x3_4x4_f) + +void winograd_4x3_4x4_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform4X3::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); +} + +void winograd_4x3_4x4_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + megdnn_assert(IC % 4 == 0); + auto unit_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + 4 * ALPHA * ALPHA; + for (size_t ic = 0; ic < IC; ic += 4) { + for (size_t unit_idx = 0; unit_idx < nr_units_in_tile; ++unit_idx) { + size_t index = unit_start_idx + unit_idx; + size_t oht = index / unit_w; + size_t owt = index % unit_w; + int ih_start = static_cast(oht * OUTPUT_BLOCK_SIZE - PH); + int iw_start = static_cast(owt * OUTPUT_BLOCK_SIZE - PW); + if (ih_start >= 0 && ih_start + 6 <= static_cast(IH) && + iw_start >= 0 && iw_start + 6 <= static_cast(IW)) { + InputTransform4X3::transpose( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + } else { + InputTransform4X3::transpose( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + } + InputTransform4X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, IC); + } + } +} + +void winograd_4x3_4x4_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_mode, ...) \ + OutputTransform4X3<_bmode, _nonline_mode>::transform(__VA_ARGS__); + auto unit_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc += 4) { + size_t oc_index = oc - oc_start; + for (size_t unit_idx = 0; unit_idx < nr_units_in_tile; ++unit_idx) { + size_t index = unit_idx + unit_start_idx; + size_t oht = index / unit_w; + size_t owt = index % unit_w; + size_t oh_start = oht * OUTPUT_BLOCK_SIZE; + size_t ow_start = owt * OUTPUT_BLOCK_SIZE; + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F43_4x4, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace fallback +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_f43_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f43_mk4_nchw44.cpp new file mode 100644 index 00000000..65757546 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f43_mk4_nchw44.cpp @@ -0,0 +1,1181 @@ +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F43_mk4) + +using namespace megdnn; +using namespace fallback; + +namespace { + +constexpr size_t alpha = 4 + 3 - 1; +constexpr size_t pack_size = 4; +constexpr float input_parameters[4] = {4.0f, 5.0f, 2.0f, 0.0f}; +constexpr float output_parameters[4] = {1.0f, 2.0f, 4.0f, 8.0f}; + +struct InputTransformF43_NCHW44 { + template + static void transform( + const float* input, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC, int ih_start, int iw_start, + size_t IH, size_t IW, const bool* ih_valid, const bool* iw_valid) { + // BT * d * B + size_t ICB = IC / pack_size; + size_t icb = ic / pack_size; + +#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MADD(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) +#define MSUB(a, b, c, d) GiMultiplySubScalarFloat32(a, b, *(c + d)) + const float* v0 = input_parameters; +#else +#define MADD(a, b, c, d) GiSimdFmaLane(a, b, c, d) +#define MSUB(a, b, c, d) GiFmsqLaneQFloat32(a, b, c, d) + GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters); +#endif + // B + // 4 0 0 0 0 0 + // 0 -4 4 -2 2 4 + // -5 -4 -4 -1 -1 0 + // 0 1 -1 2 -2 -5 + // 1 1 1 1 1 0 + // 0 0 0 0 0 1 + + const float* input_ptr = + input + ic * IH * IW + ih_start * IW * 4 + iw_start * 4; + GI_FLOAT32_t zero = GiZeroFloat32(); +#define cb(i, j) GI_FLOAT32_t d##i##j; + UNROLL_CALL_NOWRAPPER_D2(4, 6, cb); +#undef cb +#define cb(i) GI_FLOAT32_t t##i; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 0 -> d00 ... d05 + const float* line_ptr = input_ptr; + if (inner) { +#define cb(i) d0##i = GiLoadFloat32(line_ptr + i * pack_size); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[0]) { +#define cb(i) d0##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) d0##i = zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + // load line 4 -> d30 ... t35 + line_ptr = input_ptr + 4 * IW * 4; + if (inner) { +#define cb(i) \ + d3##i = GiLoadFloat32(line_ptr + i * pack_size); \ + t##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[4]) { +#define cb(i) \ + d3##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; \ + t##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) \ + d3##i = zero; \ + t##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + // load line 2 -> d20 ... d25 + line_ptr = input_ptr + 2 * IW * 4; + if (inner) { +#define cb(i) \ + d2##i = GiLoadFloat32(line_ptr + i * pack_size); \ + t##i = MSUB(t##i, d2##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[2]) { +#define cb(i) \ + d2##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; \ + t##i = MSUB(t##i, d2##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) \ + d2##i = zero; \ + t##i = MSUB(t##i, d2##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + // load line 3 -> d10 ... d15 + line_ptr = input_ptr + 3 * IW * 4; + if (inner) { +#define cb(i) d1##i = GiLoadFloat32(line_ptr + i * pack_size); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[3]) { +#define cb(i) d1##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) d1##i = zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + float* buf_ptr = input_transform_buf + icb * nr_units_in_tile * pack_size + + unit_idx * pack_size; + + d00 = MADD(t4, t0, v0, 0); + d00 = MSUB(d00, t2, v0, 1); + GiStoreFloat32(buf_ptr, d00); + d00 = MSUB(t3, t1, v0, 0); + d01 = MSUB(t4, t2, v0, 0); + d02 = ADDF(d00, d01); + GiStoreFloat32(buf_ptr + 1 * ICB * nr_units_in_tile * pack_size, d02); + d02 = SUBF(d01, d00); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, d02); + d00 = SUBF(t3, t1); + d01 = SUBF(t4, t2); + d02 = MADD(d01, d00, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, d02); + d02 = MSUB(d01, d00, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, d02); + d01 = SUBF(t5, t3); + d02 = MSUB(d01, d00, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, d02); + +// ln4 - ln2 -> t +#define cb(i) t##i = SUBF(d3##i, d2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 1 -> d00 ... d05 + line_ptr = input_ptr + IW * 4; + if (inner) { +#define cb(i) d0##i = GiLoadFloat32(line_ptr + i * pack_size); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[1]) { +#define cb(i) d0##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) d0##i = zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + +// ln4 - 4 * ln2 -> ln4 +#define cb(i) d3##i = MSUB(d3##i, d2##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +// ln3 - 4 * ln1 -> ln2 +#define cb(i) d2##i = MSUB(d1##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +// ln3 - ln1 -> ln3 +#define cb(i) d1##i = SUBF(d1##i, d0##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +// (ln4 - 4 * ln2)[ln4] + (ln3 - 4 * ln1)[ln2] -> ln1 +#define cb(i) d0##i = ADDF(d3##i, d2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +// (ln4 - 4 * ln2)[ln4] - (ln3 - 4 * ln1)[ln2] -> ln2 +#define cb(i) d2##i = SUBF(d3##i, d2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // ln4(d30 ... d35) is free until now + buf_ptr = input_transform_buf + 1 * alpha * ICB * nr_units_in_tile * pack_size + + icb * nr_units_in_tile * pack_size + unit_idx * pack_size; + d30 = MADD(d04, d00, v0, 0); + d30 = MSUB(d30, d02, v0, 1); + GiStoreFloat32(buf_ptr, d30); + d30 = MSUB(d03, d01, v0, 0); + d32 = MSUB(d04, d02, v0, 0); + d31 = ADDF(d30, d32); + GiStoreFloat32(buf_ptr + ICB * nr_units_in_tile * pack_size, d31); + d31 = SUBF(d32, d30); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, d31); + d30 = SUBF(d03, d01); + d31 = SUBF(d04, d02); + d32 = MADD(d31, d30, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, d32); + d32 = MSUB(d31, d30, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, d32); + d31 = SUBF(d05, d03); + d32 = MSUB(d31, d30, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, d32); + + buf_ptr = input_transform_buf + 2 * alpha * ICB * nr_units_in_tile * pack_size + + icb * nr_units_in_tile * pack_size + unit_idx * pack_size; + d33 = MADD(d24, d20, v0, 0); + d33 = MSUB(d33, d22, v0, 1); + GiStoreFloat32(buf_ptr, d33); + d33 = MSUB(d23, d21, v0, 0); + d35 = MSUB(d24, d22, v0, 0); + d34 = ADDF(d33, d35); + GiStoreFloat32(buf_ptr + ICB * nr_units_in_tile * pack_size, d34); + d34 = SUBF(d35, d33); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, d34); + d33 = SUBF(d23, d21); + d34 = SUBF(d24, d22); + d35 = MADD(d34, d33, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, d35); + d35 = MSUB(d34, d33, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, d35); + d34 = SUBF(d25, d23); + d35 = MSUB(d34, d33, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, d35); + +// (ln4 - ln2)[t] + (ln3 - ln1)[ln3] * 2 -> ln4 +#define cb(i) d3##i = MADD(t##i, d1##i, v0, 2); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + +// (ln4 - ln2)[t] - (ln3 - ln1)[ln3] * 2 -> ln3 +#define cb(i) d1##i = MSUB(t##i, d1##i, v0, 2); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + // t is free + buf_ptr = input_transform_buf + 3 * alpha * ICB * nr_units_in_tile * pack_size + + icb * nr_units_in_tile * pack_size + unit_idx * pack_size; + t0 = MADD(d34, d30, v0, 0); + t0 = MSUB(t0, d32, v0, 1); + GiStoreFloat32(buf_ptr, t0); + t0 = MSUB(d33, d31, v0, 0); + t2 = MSUB(d34, d32, v0, 0); + t1 = ADDF(t0, t2); + GiStoreFloat32(buf_ptr + ICB * nr_units_in_tile * pack_size, t1); + t1 = SUBF(t2, t0); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, t1); + t0 = SUBF(d33, d31); + t1 = SUBF(d34, d32); + t2 = MADD(t1, t0, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, t2); + t2 = MSUB(t1, t0, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, t2); + t1 = SUBF(d35, d33); + t2 = MSUB(t1, t0, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, t2); + + buf_ptr = input_transform_buf + 4 * alpha * ICB * nr_units_in_tile * pack_size + + icb * nr_units_in_tile * pack_size + unit_idx * pack_size; + t3 = MADD(d14, d10, v0, 0); + t3 = MSUB(t3, d12, v0, 1); + GiStoreFloat32(buf_ptr, t3); + t3 = MSUB(d13, d11, v0, 0); + t5 = MSUB(d14, d12, v0, 0); + t4 = ADDF(t3, t5); + GiStoreFloat32(buf_ptr + ICB * nr_units_in_tile * pack_size, t4); + t4 = SUBF(t5, t3); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, t4); + t3 = SUBF(d13, d11); + t4 = SUBF(d14, d12); + t5 = MADD(t4, t3, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, t5); + t5 = MSUB(t4, t3, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, t5); + t4 = SUBF(d15, d13); + t5 = MSUB(t4, t3, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, t5); + + // load line 5 -> d30 ... d35 + line_ptr = input_ptr + 5 * IW * 4; + if (inner) { +#define cb(i) d3##i = GiLoadFloat32(line_ptr + i * pack_size); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[5]) { +#define cb(i) d3##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) d3##i = zero; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + // load line 1 -> d0 ... d5 + line_ptr = input_ptr + IW * 4; + if (inner) { +#define cb(i) \ + d0##i = GiLoadFloat32(line_ptr + i * pack_size); \ + d3##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[1]) { +#define cb(i) \ + d0##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; \ + d3##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) \ + d0##i = zero; \ + d3##i = MADD(d3##i, d0##i, v0, 0); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + // load line 3 -> d10 ... d15 + line_ptr = input_ptr + 3 * IW * 4; + if (inner) { +#define cb(i) \ + d1##i = GiLoadFloat32(line_ptr + i * pack_size); \ + d3##i = MSUB(d3##i, d1##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { + if (ih_valid[3]) { +#define cb(i) \ + d1##i = iw_valid[i] ? GiLoadFloat32(line_ptr + i * pack_size) : zero; \ + d3##i = MSUB(d3##i, d1##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } else { +#define cb(i) \ + d1##i = zero; \ + d3##i = MSUB(d3##i, d1##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + } + } + + buf_ptr = input_transform_buf + 5 * alpha * ICB * nr_units_in_tile * pack_size + + icb * nr_units_in_tile * pack_size + unit_idx * pack_size; + t0 = MADD(d34, d30, v0, 0); + t0 = MSUB(t0, d32, v0, 1); + GiStoreFloat32(buf_ptr, t0); + t0 = MSUB(d33, d31, v0, 0); + t2 = MSUB(d34, d32, v0, 0); + t1 = ADDF(t0, t2); + GiStoreFloat32(buf_ptr + ICB * nr_units_in_tile * pack_size, t1); + t1 = SUBF(t2, t0); + GiStoreFloat32(buf_ptr + 2 * ICB * nr_units_in_tile * pack_size, t1); + t0 = SUBF(d33, d31); + t1 = SUBF(d34, d32); + t2 = MADD(t1, t0, v0, 2); + GiStoreFloat32(buf_ptr + 3 * ICB * nr_units_in_tile * pack_size, t2); + t2 = MSUB(t1, t0, v0, 2); + GiStoreFloat32(buf_ptr + 4 * ICB * nr_units_in_tile * pack_size, t2); + t1 = SUBF(d35, d33); + t2 = MSUB(t1, t0, v0, 0); + GiStoreFloat32(buf_ptr + 5 * ICB * nr_units_in_tile * pack_size, t2); + +#undef MSUB +#undef MADD + } +}; // InputTransformF43_NCHW44 + +template +struct OutputTransformF43_NCHW44 { + static inline void transform( + const float* output_transform_buf, const float* bias, float* output, + const size_t oh_start, const size_t ow_start, const size_t OH, + const size_t OW, const size_t oc_start, const size_t oc_end, + const size_t oc_index, const size_t unit_idx, const size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + Op op(src_dtype, dst_dtype); + //! AT * m * A + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / pack_size; + size_t ocb = oc_index / pack_size; + size_t col_step = OCB * nr_units_in_tile * 4; + size_t row_step = alpha * col_step; + +#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MADD(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) +#define MSUB(a, b, c, d) GiMultiplySubScalarFloat32(a, b, *(c + d)) + const float* v0 = output_parameters; +#else +#define MADD(a, b, c, d) GiSimdFmaLane(a, b, c, d) +#define MSUB(a, b, c, d) GiFmsqLaneQFloat32(a, b, c, d) + GI_FLOAT32_t v0 = GiLoadFloat32(output_parameters); +#endif + + GI_FLOAT32_t vbias = GiZeroFloat32(); +#define cb(i, j) GI_FLOAT32_t v##i##j; + UNROLL_CALL_NOWRAPPER_D2(5, 6, cb); +#undef cb + + const float* buf_base = + output_transform_buf + ocb * nr_units_in_tile * 4 + unit_idx * 4; + const float* buf_ptr = nullptr; + + // load line 1 -> v10 ... v15 + buf_ptr = buf_base + row_step; +#define cb(i) v1##i = GiLoadFloat32(buf_ptr + i * col_step); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 2 -> v20 ... v25 + buf_ptr = buf_base + 2 * row_step; +#define cb(i) \ + v2##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v0##i = ADDF(v1##i, v2##i); \ + v1##i = SUBF(v1##i, v2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 3 -> v30 ... v35 + buf_ptr = buf_base + 3 * row_step; +#define cb(i) v3##i = GiLoadFloat32(buf_ptr + i * col_step); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 4 -> v40 ... v45 + buf_ptr = buf_base + 4 * row_step; +#define cb(i) \ + v4##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v2##i = ADDF(v3##i, v4##i); \ + v3##i = SUBF(v3##i, v4##i); \ + v4##i = MADD(v0##i, v2##i, v0, 2); \ + v2##i = ADDF(v2##i, v0##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + float* output_base = output + oc * OH * OW + oh_start * OW * pack_size + + ow_start * pack_size; + float* output_ptr = output_base + 2 * OW * pack_size; + const float* bias_base = nullptr; + const float* bias_ptr = nullptr; + if (bmode == BiasMode::BIAS) { + bias_base = bias + oc * OH * OW + oh_start * OW * pack_size + + ow_start * pack_size; + bias_ptr = bias_base + 2 * OW * pack_size; + } + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = GiLoadFloat32(bias + oc); + } + v00 = ADDF(v41, v42); + v01 = ADDF(v43, v44); + v02 = ADDF(v40, v00); + v02 = ADDF(v02, v01); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr, v02); + + v03 = SUBF(v41, v42); + v04 = SUBF(v43, v44); + v05 = MADD(v03, v04, v0, 1); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + pack_size, v05); + + v02 = MADD(v00, v01, v0, 2); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr + 2 * pack_size, v02); + + v05 = MADD(v03, v04, v0, 3); + v05 = ADDF(v05, v45); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + 3 * pack_size, v05); + + buf_ptr = buf_base; +#define cb(i) \ + v4##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v4##i = ADDF(v4##i, v2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + output_ptr = output_base; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base; + } + v00 = ADDF(v41, v42); + v01 = ADDF(v43, v44); + v02 = ADDF(v40, v00); + v02 = ADDF(v02, v01); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr, v02); + + v03 = SUBF(v41, v42); + v04 = SUBF(v43, v44); + v05 = MADD(v03, v04, v0, 1); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + pack_size, v05); + + v02 = MADD(v00, v01, v0, 2); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr + 2 * pack_size, v02); + + v05 = MADD(v03, v04, v0, 3); + v05 = ADDF(v05, v45); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + 3 * pack_size, v05); + +#define cb(i) v4##i = MADD(v1##i, v3##i, v0, 1); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + output_ptr = output_base + OW * pack_size; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base + OW * pack_size; + } + v00 = ADDF(v41, v42); + v01 = ADDF(v43, v44); + v02 = ADDF(v40, v00); + v02 = ADDF(v02, v01); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr, v02); + + v03 = SUBF(v41, v42); + v04 = SUBF(v43, v44); + v05 = MADD(v03, v04, v0, 1); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + pack_size, v05); + + v02 = MADD(v00, v01, v0, 2); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr + 2 * pack_size, v02); + + v05 = MADD(v03, v04, v0, 3); + v05 = ADDF(v05, v45); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + 3 * pack_size, v05); + + buf_ptr = buf_base + 5 * row_step; +#define cb(i) \ + v2##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v1##i = MADD(v1##i, v3##i, v0, 3); \ + v2##i = ADDF(v1##i, v2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + output_ptr = output_base + 3 * OW * pack_size; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base + 3 * OW * pack_size; + } + v00 = ADDF(v21, v22); + v01 = ADDF(v23, v24); + v02 = ADDF(v20, v00); + v02 = ADDF(v02, v01); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr, v02); + + v03 = SUBF(v21, v22); + v04 = SUBF(v23, v24); + v05 = MADD(v03, v04, v0, 1); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + pack_size, v05); + + v02 = MADD(v00, v01, v0, 2); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr + 2 * pack_size, v02); + + v05 = MADD(v03, v04, v0, 3); + v05 = ADDF(v05, v25); + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + 3 * pack_size, v05); + +#undef MSUB +#undef MADD + } + + static inline void transform( + const float* output_transform_buf, const float* bias, float* output, + const size_t oh_start, const size_t ow_start, const size_t OH, + const size_t OW, const size_t oc_start, const size_t oc_end, + const size_t oc_index, const size_t unit_idx, const size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype, const size_t num_oh_valid, + const size_t num_ow_valid) { + Op op(src_dtype, dst_dtype); + //! AT * m * A + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / pack_size; + size_t ocb = oc_index / pack_size; + size_t col_step = OCB * nr_units_in_tile * 4; + size_t row_step = alpha * col_step; + +#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS) +//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use +//! GiMultiplyAddScalarFloat32 +#define MADD(a, b, c, d) GiMultiplyAddScalarFloat32(a, b, *(c + d)) +#define MSUB(a, b, c, d) GiMultiplySubScalarFloat32(a, b, *(c + d)) + const float* v0 = output_parameters; +#else +#define MADD(a, b, c, d) GiSimdFmaLane(a, b, c, d) +#define MSUB(a, b, c, d) GiFmsqLaneQFloat32(a, b, c, d) + GI_FLOAT32_t v0 = GiLoadFloat32(output_parameters); +#endif + + GI_FLOAT32_t vbias = GiZeroFloat32(); +#define cb(i, j) GI_FLOAT32_t v##i##j; + UNROLL_CALL_NOWRAPPER_D2(5, 6, cb); +#undef cb + + const float* buf_base = + output_transform_buf + ocb * nr_units_in_tile * 4 + unit_idx * 4; + const float* buf_ptr = nullptr; + + // load line 1 -> v10 ... v15 + buf_ptr = buf_base + row_step; +#define cb(i) v1##i = GiLoadFloat32(buf_ptr + i * col_step); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 2 -> v20 ... v25 + buf_ptr = buf_base + 2 * row_step; +#define cb(i) \ + v2##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v0##i = ADDF(v1##i, v2##i); \ + v1##i = SUBF(v1##i, v2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 3 -> v30 ... v35 + buf_ptr = buf_base + 3 * row_step; +#define cb(i) v3##i = GiLoadFloat32(buf_ptr + i * col_step); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // load line 4 -> v40 ... v45 + buf_ptr = buf_base + 4 * row_step; +#define cb(i) \ + v4##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v2##i = ADDF(v3##i, v4##i); \ + v3##i = SUBF(v3##i, v4##i); \ + v4##i = MADD(v0##i, v2##i, v0, 2); \ + v2##i = ADDF(v2##i, v0##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // result line 2, v40 ... v45 -> v02 ... v05 + // v40 ... v45 is free. + v00 = ADDF(v41, v42); + v01 = ADDF(v43, v44); + v02 = ADDF(v40, v00); + v02 = ADDF(v02, v01); + + v04 = MADD(v00, v01, v0, 2); + + v00 = SUBF(v41, v42); + v01 = SUBF(v43, v44); + v03 = MADD(v00, v01, v0, 1); + + v05 = MADD(v00, v01, v0, 3); + v05 = ADDF(v05, v45); + + buf_ptr = buf_base; +#define cb(i) \ + v4##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v4##i = ADDF(v4##i, v2##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // result line 0 + // v40 ... v45 -> v22 ... v25 + v20 = ADDF(v41, v42); + v21 = ADDF(v43, v44); + v22 = ADDF(v40, v20); + v22 = ADDF(v22, v21); + + v24 = MADD(v20, v21, v0, 2); + + v20 = SUBF(v41, v42); + v21 = SUBF(v43, v44); + v23 = MADD(v20, v21, v0, 1); + + v25 = MADD(v20, v21, v0, 3); + v25 = ADDF(v25, v45); + +#define cb(i) \ + v4##i = MADD(v1##i, v3##i, v0, 1); \ + v3##i = MADD(v1##i, v3##i, v0, 3); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // result line 1 + // v40 ... v45 -> v12 ... v15 + v10 = ADDF(v41, v42); + v11 = ADDF(v43, v44); + v12 = ADDF(v40, v10); + v12 = ADDF(v12, v11); + + v14 = MADD(v10, v11, v0, 2); + + v10 = SUBF(v41, v42); + v11 = SUBF(v43, v44); + v13 = MADD(v10, v11, v0, 1); + + v15 = MADD(v10, v11, v0, 3); + v15 = ADDF(v15, v45); + + buf_ptr = buf_base + 5 * row_step; +#define cb(i) \ + v4##i = GiLoadFloat32(buf_ptr + i * col_step); \ + v4##i = ADDF(v3##i, v4##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + // result line 3 + // v40 ... v45 -> v32 ... v35 + v30 = ADDF(v41, v42); + v31 = ADDF(v43, v44); + v32 = ADDF(v40, v30); + v32 = ADDF(v32, v31); + + v34 = MADD(v30, v31, v0, 2); + + v30 = SUBF(v41, v42); + v31 = SUBF(v43, v44); + v33 = MADD(v30, v31, v0, 1); + + v35 = MADD(v30, v31, v0, 3); + v35 = ADDF(v35, v45); + + float* output_base = output + oc * OH * OW + oh_start * OW * pack_size + + ow_start * pack_size; + float* output_ptr = nullptr; + const float* bias_base = nullptr; + const float* bias_ptr = nullptr; + if (bmode == BiasMode::BIAS) { + bias_base = bias + oc * OH * OW + oh_start * OW * pack_size + + ow_start * pack_size; + } + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = GiLoadFloat32(bias + oc); + } + + switch (num_oh_valid) { + case 4: { + output_ptr = output_base + 3 * OW * pack_size; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base + 3 * OW * pack_size; + } + switch (num_ow_valid) { + case 4: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v35 = ADDF(v35, vbias); + v35 = op(v35); + GiStoreFloat32(output_ptr + 3 * pack_size, v35); + MEGDNN_FALLTHRU; + case 3: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v34 = ADDF(v34, vbias); + v34 = op(v34); + GiStoreFloat32(output_ptr + 2 * pack_size, v34); + MEGDNN_FALLTHRU; + case 2: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v33 = ADDF(v33, vbias); + v33 = op(v33); + GiStoreFloat32(output_ptr + pack_size, v33); + MEGDNN_FALLTHRU; + case 1: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v32 = ADDF(v32, vbias); + v32 = op(v32); + GiStoreFloat32(output_ptr, v32); + } + MEGDNN_FALLTHRU; + } + case 3: { + output_ptr = output_base + 2 * OW * pack_size; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base + 2 * OW * pack_size; + } + switch (num_ow_valid) { + case 4: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v05 = ADDF(v05, vbias); + v05 = op(v05); + GiStoreFloat32(output_ptr + 3 * pack_size, v05); + MEGDNN_FALLTHRU; + case 3: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v04 = ADDF(v04, vbias); + v04 = op(v04); + GiStoreFloat32(output_ptr + 2 * pack_size, v04); + MEGDNN_FALLTHRU; + case 2: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v03 = ADDF(v03, vbias); + v03 = op(v03); + GiStoreFloat32(output_ptr + pack_size, v03); + MEGDNN_FALLTHRU; + case 1: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v02 = ADDF(v02, vbias); + v02 = op(v02); + GiStoreFloat32(output_ptr, v02); + } + MEGDNN_FALLTHRU; + } + case 2: { + output_ptr = output_base + OW * pack_size; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base + OW * pack_size; + } + switch (num_ow_valid) { + case 4: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v15 = ADDF(v15, vbias); + v15 = op(v15); + GiStoreFloat32(output_ptr + 3 * pack_size, v15); + MEGDNN_FALLTHRU; + case 3: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v14 = ADDF(v14, vbias); + v14 = op(v14); + GiStoreFloat32(output_ptr + 2 * pack_size, v14); + MEGDNN_FALLTHRU; + case 2: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v13 = ADDF(v13, vbias); + v13 = op(v13); + GiStoreFloat32(output_ptr + pack_size, v13); + MEGDNN_FALLTHRU; + case 1: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v12 = ADDF(v12, vbias); + v12 = op(v12); + GiStoreFloat32(output_ptr, v12); + } + MEGDNN_FALLTHRU; + } + case 1: { + output_ptr = output_base; + if (bmode == BiasMode::BIAS) { + bias_ptr = bias_base; + } + switch (num_ow_valid) { + case 4: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 3 * pack_size); + } + v25 = ADDF(v25, vbias); + v25 = op(v25); + GiStoreFloat32(output_ptr + 3 * pack_size, v25); + MEGDNN_FALLTHRU; + case 3: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + 2 * pack_size); + } + v24 = ADDF(v24, vbias); + v24 = op(v24); + GiStoreFloat32(output_ptr + 2 * pack_size, v24); + MEGDNN_FALLTHRU; + case 2: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr + pack_size); + } + v23 = ADDF(v23, vbias); + v23 = op(v23); + GiStoreFloat32(output_ptr + pack_size, v23); + MEGDNN_FALLTHRU; + case 1: + if (bmode == BiasMode::BIAS) { + vbias = GiLoadFloat32(bias_ptr); + } + v22 = ADDF(v22, vbias); + v22 = op(v22); + GiStoreFloat32(output_ptr, v22); + } + } + } + +#undef MSUB +#undef MADD + } + + static void wrapper( + const float* output_transform_buf, const float* bias, float* output, + const size_t OH, const size_t OW, const size_t oc_start, + const size_t oc_end, const size_t unit_start_idx, + const size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + auto units_w = div_ceil(OW, 4); + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * 4; + size_t ow_start = nw * 4; + megdnn_assert(oh_start < OH); + megdnn_assert(ow_start < OW); + size_t num_valid_oh = std::min(static_cast(4), OH - oh_start), + num_valid_ow = std::min(static_cast(4), OW - ow_start); + if (num_valid_oh == num_valid_ow && num_valid_oh == 4) { + transform( + output_transform_buf, bias, output, oh_start, ow_start, OH, + OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, + src_dtype, dst_dtype); + } else { + transform( + output_transform_buf, bias, output, oh_start, ow_start, OH, + OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, + src_dtype, dst_dtype, num_valid_oh, num_valid_ow); + } + } + } + } +}; // OutputTransformF43_NCHW44 +} // namespace + +namespace megdnn { +namespace fallback { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F43_mk4_f_nchw44) + +void winograd_F43_mk4_f_nchw44::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + constexpr size_t pack_size = 4; + MEGDNN_MARK_USED_VAR(transform_mid_buf); + megdnn_assert( + (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); + + size_t ICB = IC / pack_size; + for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) { + for (size_t icb = 0; icb < ICB; icb++) { + for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { + const float* fptr = filter + + (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * + pack_size * pack_size + + ic_inner * pack_size; + +#define cb(m, n) \ + GI_FLOAT32_t g##m##n = \ + GiLoadFloat32(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); + UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) +#undef cb + + //! G + // 1/4 0 0 + // -1/6 -1/6 -1/6 + // -1/6 1/6 -1/6 + // 1/24 1/12 1/6 + // 1/24 -1/12 1/6 + // 0 0 1 + +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = MULSF(g##0##n, 0.25f); \ + tmp0 = MULSF(ADDF(g##0##n, g##2##n), -0.1666667f); \ + tmp1 = MULSF(g##1##n, -0.1666667f); \ + auto wd##n##1 = ADDF(tmp0, tmp1); \ + auto wd##n##2 = SUBF(tmp0, tmp1); \ + tmp0 = ADDF(MULSF(g##0##n, 0.0416667f), MULSF(g##2##n, 0.1666667f)); \ + tmp1 = MULSF(g##1##n, 0.0833333f); \ + auto wd##n##3 = ADDF(tmp0, tmp1); \ + auto wd##n##4 = SUBF(tmp0, tmp1); \ + auto wd##n##5 = g##2##n; + GI_FLOAT32_t tmp0, tmp1; + UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); + UNROLL_CALL_RAW(6, FILTER_TRANSFORM, ret, wd); +#undef FILTER_TRANSFORM +#define cb_save(m, n) \ + GiStoreFloat32( \ + filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ + icb * pack_size * pack_size + ic_inner * pack_size, \ + ret##m##n); + UNROLL_CALL_NOWRAPPER_D2(6, 6, cb_save) +#undef cb_save + } + } + } +} + +void winograd_F43_mk4_f_nchw44::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); + constexpr size_t pack_size = 4; + megdnn_assert(IC % pack_size == 0); + constexpr int alpha = 3 + 4 - 1; + + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + + bool ih_valid[6], iw_valid[6]; + + for (size_t ic = 0; ic < IC; ic += pack_size) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransformF43_NCHW44::transform( + input, input_transform_buf, unit_idx, nr_units_in_tile, ic, IC, + ih_start, iw_start, IH, IW, ih_valid, iw_valid); + } else { + for (int iho = 0; iho < alpha; ++iho) { + ih_valid[iho] = + (iho + ih_start >= 0 && + iho + ih_start < static_cast(IH)); + } + for (int iwo = 0; iwo < alpha; ++iwo) { + iw_valid[iwo] = + (iwo + iw_start >= 0 && + iwo + iw_start < static_cast(IW)); + } + InputTransformF43_NCHW44::transform( + input, input_transform_buf, unit_idx, nr_units_in_tile, ic, IC, + ih_start, iw_start, IH, IW, ih_valid, iw_valid); + } + } + } +} + +void winograd_F43_mk4_f_nchw44::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); +#define cb(_bmode, _nonline_op, ...) \ + OutputTransformF43_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::wrapper( \ + output_transform_buf, bias, output, OH, OW, oc_start, oc_end, \ + unit_start_idx, nr_units_in_tile, src_dtype, dst_dtype); + + constexpr size_t pack_size = 4; + size_t OC = oc_end - oc_start; + megdnn_assert( + OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); + + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F43_mk4, cb, float, float, bmode, + nonline_mode); +#undef cb +} + +} // namespace winograd +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 18db4d0e..87b49b00 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -121,7 +121,7 @@ public: for (auto&& algo : matmul_algos) { if (is_naive(algo)) continue; - for (uint32_t tile_size : {16, 8, 24, 32}) { + for (uint32_t tile_size : {16, 8, 24, 32, 48, 68}) { refhold.emplace_back(new AlgoFP32WinogradF23_4x4( static_cast(algo), tile_size)); @@ -130,10 +130,18 @@ public: static_cast(algo), tile_size)); m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF43_4x4( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( static_cast(algo), tile_size)); m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF43_4x4_NCHW44( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( static_cast(algo), tile_size)); diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 5a0fe75a..a2222b51 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -219,9 +219,11 @@ public: GI_COMMON_WINOGRAD_F63_FP32, GI_COMMON_WINOGRAD_F43_FP32, GI_COMMON_WINOGRAD_F63_4X4_FP32, + GI_COMMON_WINOGRAD_F43_4X4_FP32, GI_COMMON_WINOGRAD_F54_FP32, GI_COMMON_WINOGRAD_F45_FP32, GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, + GI_COMMON_WINOGRAD_F43_4X4_NCHW44_F32, GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, GI_COMMON_DIRECT_FP32, @@ -382,9 +384,11 @@ private: class AlgoFP32WinogradF63; class AlgoFP32WinogradF43; class AlgoFP32WinogradF63_4x4; + class AlgoFP32WinogradF43_4x4; class AlgoFP32WinogradF54; class AlgoFP32WinogradF45; class AlgoFP32WinogradF23_4x4_NCHW44; + class AlgoFP32WinogradF43_4x4_NCHW44; class AlgoFP32WinogradF63_4x4_NCHW44; class AlgoFP32WinogradF73_4x4_NCHW44; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 3bdfa2c7..f54f8849 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -1013,6 +1013,27 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F43_F63) { handle(), 3); #endif } + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_44_F43_F23) { +#if MEGDNN_AARCH64 + benchmark_winograd_compare( + "WINOGRAD:.*:4:4:.*:3", "WINOGRAD:.*:4:2", handle(), 3, 4); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_WINOGRAD_F43_44) { +#if MEGDNN_AARCH64 + benchmark_winograd_weight_preprocess("WINOGRAD:.*:4:4:.*:3", handle(), 3, 4); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_WINOGRAD_F43_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_weight_preprocess( + "WINOGRAD_NCHW44:.*:4:4:.*:3", handle(), 3, 4, 4); +#endif +} + TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { #if MEGDNN_AARCH64 benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 6476176f..0130db00 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -902,7 +902,8 @@ void check_conv_bias( } #if MEGDNN_WITH_BENCHMARK std::vector get_winograd_benchmark_args( - size_t kernel, size_t pack_size) { + size_t kernel, size_t pack_size, size_t io_pack_size) { + megdnn_assert(io_pack_size == 1 || io_pack_size == 4); std::vector args; auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, size_t p) { if (ic % pack_size != 0 || oc % pack_size != 0) @@ -915,11 +916,20 @@ std::vector get_winograd_benchmark_args( param.pad_h = p; param.pad_w = p; - args.push_back(conv_bias::TestArg{ - param, - TensorShape{1, ic, h, w}, - TensorShape{oc, ic, kernel, kernel}, - {1, oc, 1, 1}}); + if (io_pack_size == 4) { + param.format = param::ConvBias::Format::NCHW44; + args.push_back(conv_bias::TestArg{ + param, + TensorShape{1, ic / 4, h, w, 4}, + TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, + {1, oc / 4, 1, 1, 4}}); + } else { + args.push_back(conv_bias::TestArg{ + param, + TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}, + {1, oc, 1, 1}}); + } }; for (size_t ic : {8, 16, 32, 64}) { @@ -950,8 +960,9 @@ std::vector get_winograd_benchmark_args( } void benchmark_winograd( - const char* algo_name, Handle* handle, size_t kernel, size_t pack_size) { - auto&& args = get_winograd_benchmark_args(kernel, pack_size); + const char* algo_name, Handle* handle, size_t kernel, size_t pack_size, + size_t io_pack_size) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size, io_pack_size); using namespace conv_bias; constexpr size_t RUN = 10; Benchmarker benchmark(handle); @@ -969,10 +980,17 @@ void benchmark_winograd( opr->deduce_layout( {arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()}, {arg.bias, dtype::Float32()}, {}, dst_layout); - //! dst.nr_elems * IC * FH * FW * 2 - float computations = dst_layout.total_nr_elems() * arg.filter[1] * - arg.filter[2] * arg.filter[3] * 2.0 / - (1024 * 1024 * 1024) * 1e3; + float computations = 0.0; + if (io_pack_size == 1) { + //! dst.nr_elems * IC * FH * FW * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * 2.0 / (1024 * 1024 * 1024) * 1e3; + } else { + //! dst.nr_elems * IC/4 * FH * FW * 4 * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * + 1e3; + } param::Convolution conv_param; conv_param.pad_h = arg.param.pad_h; @@ -999,9 +1017,9 @@ void benchmark_winograd( // usage of weight pre-processing for winograd benchmark void benchmark_winograd_weight_preprocess( - const char* algo_name, megdnn::Handle* handle, size_t kernel, - size_t pack_size) { - auto&& args = get_winograd_benchmark_args(kernel, pack_size); + const char* algo_name, megdnn::Handle* handle, size_t kernel, size_t pack_size, + size_t io_pack_size) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size, io_pack_size); using namespace conv_bias; constexpr size_t RUN = 10; @@ -1018,16 +1036,17 @@ void benchmark_winograd_weight_preprocess( opr->deduce_layout( {arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()}, {arg.bias, dtype::Float32()}, {}, dst_layout); - //! dst.nr_elems * IC * FH * FW * 2 - float computations = dst_layout.total_nr_elems() * arg.filter[1] * - arg.filter[2] * arg.filter[3] * 2.0 / - (1024 * 1024 * 1024) * 1e3; - - param::Convolution conv_param; - conv_param.pad_h = arg.param.pad_h; - conv_param.pad_w = arg.param.pad_w; - conv_param.stride_h = arg.param.stride_h; - conv_param.stride_w = arg.param.stride_w; + float computations = 0.0; + if (io_pack_size == 1) { + //! dst.nr_elems * IC * FH * FW * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * 2.0 / (1024 * 1024 * 1024) * 1e3; + } else { + //! dst.nr_elems * IC/4 * FH * FW * 4 * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * + 1e3; + } benchmark_winograd.set_param(arg.param); auto used_winograd = @@ -1045,8 +1064,8 @@ void benchmark_winograd_weight_preprocess( void benchmark_winograd_compare( const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, - size_t kernel, size_t pack_size) { - auto&& args = get_winograd_benchmark_args(kernel, pack_size); + size_t kernel, size_t pack_size, size_t io_pack_size) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size, io_pack_size); using namespace conv_bias; constexpr size_t RUN = 10; @@ -1062,16 +1081,17 @@ void benchmark_winograd_compare( opr->deduce_layout( {arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()}, {arg.bias, dtype::Float32()}, {}, dst_layout); - //! dst.nr_elems * IC * FH * FW * 2 - float computations = dst_layout.total_nr_elems() * arg.filter[1] * - arg.filter[2] * arg.filter[3] * 2.0 / - (1024 * 1024 * 1024) * 1e3; - - param::Convolution conv_param; - conv_param.pad_h = arg.param.pad_h; - conv_param.pad_w = arg.param.pad_w; - conv_param.stride_h = arg.param.stride_h; - conv_param.stride_w = arg.param.stride_w; + float computations = 0.0; + if (io_pack_size == 1) { + //! dst.nr_elems * IC * FH * FW * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * 2.0 / (1024 * 1024 * 1024) * 1e3; + } else { + //! dst.nr_elems * IC/4 * FH * FW * 4 * 2 + computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * + 1e3; + } benchmark_winograd.set_param(arg.param); auto used_winograd1 = diff --git a/dnn/test/common/conv_bias.h b/dnn/test/common/conv_bias.h index 0d0961ee..b44d0293 100644 --- a/dnn/test/common/conv_bias.h +++ b/dnn/test/common/conv_bias.h @@ -62,16 +62,16 @@ void check_conv_bias( #if MEGDNN_WITH_BENCHMARK std::vector get_winograd_benchmark_args( - size_t kernel, size_t pack_size = 1); + size_t kernel, size_t pack_size = 1, size_t io_pack_size = 1); void benchmark_winograd( const char* algo_name, megdnn::Handle* handle, size_t kernel, - size_t pack_size = 1); + size_t pack_size = 1, size_t io_pack_size = 1); void benchmark_winograd_weight_preprocess( const char* algo_name, megdnn::Handle* handle, size_t kernel, - size_t pack_size = 1); + size_t pack_size = 1, size_t io_pack_size = 1); void benchmark_winograd_compare( const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, - size_t kernel, size_t pack_size = 1); + size_t kernel, size_t pack_size = 1, size_t io_pack_size = 1); #endif // MEGDNN_WITH_BENCHMARK template void check_winograd( diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index 3f6eab53..d7c2ee2e 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -597,6 +597,25 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F63_4_NCHW44) { param::ConvBias::Format::NCHW44); } +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_NCHW44) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); + Checker checker(handle()); + check_winograd( + "4:4:16", checker, args, param::MatrixMul::Format::MK4, + param::ConvBias::Format::NCHW44); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_WEIGHT_PREPROCESS) { + using namespace conv_bias; + std::vector args = get_winograd_mk_packed_args(); + Checker> checker( + handle()); + + check_winograd("4:4:16", checker, args, param::MatrixMul::Format::MK4); +} + TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F54) { using namespace conv_bias; std::vector args = get_winograd_args(4);