diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index 787c6dfe..23a5d9cc 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -153,6 +153,27 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16) }; +class ConvBiasImpl::AlgoF16DirectNCHW88 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + +public: + AlgoF16DirectNCHW88() {} + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + } + const char* name() const override { return "F16_CONV_NCHW88_DIRECT"; } + bool usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; + } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW88_FP16) +}; + } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp new file mode 100644 index 00000000..06aaf80b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp @@ -0,0 +1,296 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/block_helper.h" +#include "src/arm_common/conv_bias/f16/algos.h" +#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" + +#include "src/arm_common/elemwise_op.h" + +#include "midout.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +using namespace megdnn; +using namespace arm_common; +using conv_fun = + std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88) +namespace { + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + size_t nr_threads = param.nr_threads; + size_t IC = fm.icpg / 8; + size_t PH = fm.padding[0]; + size_t PW = fm.padding[1]; + size_t IH2 = param.isz[0] + 2 * PH; + size_t IW2 = param.isz[1] + 2 * PW; + if (PH == 0 && PW == 0) { + return {nullptr, {}}; + } + + size_t s = (nr_threads * IC * IH2 * IW2 * 8) * sizeof(dt_float16); + return {nullptr, {s}}; +} + +void copy_padding_kern(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + auto fm = kern_param.filter_meta; + size_t group = fm.group; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t IC = fm.icpg / 8; + size_t PH = fm.padding[0]; + size_t PW = fm.padding[1]; + size_t IH2 = IH + 2 * PH; + size_t IW2 = IW + 2 * PW; + + if (PH == 0 && PW == 0) { + return; + } + + //! 
Used for get the workspace offset + size_t workspace_group_id = workspace_ids[0]; + size_t workspace_batch_id = workspace_ids[1]; + size_t channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; + + const dt_float16* sptr = + kern_param.src(batch_id, group_id, channel_id, 1, 8); + + //! copy to sptr_base to eliminate padding effect + dt_float16* sptr_base = static_cast(bundle.get(0)) + + workspace_batch_id * group * IC * IH2 * IW2 * 8 + + workspace_group_id * IC * IH2 * IW2 * 8 + + channel_id * IH2 * IW2 * 8; + std::memset(sptr_base, 0, IH2 * IW2 * 8 * sizeof(dt_float16)); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 * 8 + PW * 8, + sptr + ih * IW * 8, IW * 8 * sizeof(dt_float16)); + } +}; + +template +static void do_conv_kern(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + auto fm = kern_param.filter_meta; + size_t group = fm.group; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FW = FH; + size_t IC = fm.icpg / 8; + size_t PH = fm.padding[0]; + size_t PW = fm.padding[1]; + size_t IH2 = kern_param.isz[0] + 2 * PH; + size_t IW2 = kern_param.isz[1] + 2 * PW; + + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; + size_t channel_id = workspace_ids[2]; + + //! Used for get the workspace offset + size_t workspace_batch_id = workspace_ids[1]; + size_t workspace_group_id = workspace_ids[0]; + + const __fp16* sptr = nullptr; + if (PH == 0 && PW == 0) { + sptr = reinterpret_cast( + kern_param.src(batch_id, group_id)); + } else { + sptr = reinterpret_cast( + static_cast(bundle.get(0))) + + workspace_batch_id * group * IC * IH2 * IW2 * 8 + + workspace_group_id * IC * IH2 * IW2 * 8; + } + const __fp16* filter = reinterpret_cast( + kern_param.filter(group_id, 1)) + + channel_id * IC * FH * FW * 8 * 8; + const __fp16* bias_ptr = reinterpret_cast( + kern_param.bias(batch_id, group_id, channel_id, 1, 8)); + __fp16* dptr = reinterpret_cast<__fp16*>( + kern_param.dst(batch_id, group_id, channel_id, 1, 8)); + + conv_bias::conv_direct_fp16_nchw88( + sptr, filter, bias_ptr, dptr, IC, IH2, IW2, OH, OW); +} + +} // namespace + +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoF16DirectNCHW88::usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + auto&& fm = param.filter_meta; + auto fh = fm.spatial[0]; + int oc = fm.ocpg; + int ic = fm.icpg; + bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + (param.dst_type.enumv() == DTypeEnum::Float16))) && + (fm.format == param::Convolution::Format::NCHW88); + bool ok_src_dst = (oc % 8 == 0 && oc >= 8 && ic % 8 == 0 && ic >= 8); + bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && + (fh == 1 || fh == 2 || fh == 3 || fh == 5 || fh == 7); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + ((fm.stride[0] == 1 && fm.stride[1] == 1) || + (fm.stride[0] == 2 && fm.stride[1] == 2)); + bool ok_conv = !fm.should_flip; + bool ok_comp = param.compute_mode == Param::ComputeMode::DEFAULT; + return ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv && ok_comp; +} + +size_t ConvBiasImpl::AlgoF16DirectNCHW88::get_workspace( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88_stride1, + 
midout_iv("AlgoF16DirectNCHW88::get_workspace"_hash)) { + return get_bundle(param).total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t batch = param.n; + size_t group = fm.group; + + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + // NOTE: remain_w is not used to gen hash of midout for compatible with +// shape runtime +#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88, \ + midout_iv(#filter #bias_mode #stride #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_STRIDE_PARAM(filter, bias_mode, op) \ + switch (fm.stride[0]) { \ + case 1: \ + DO_CONV_KERN_FUN(filter, bias_mode, op, 1); \ + break; \ + case 2: \ + DO_CONV_KERN_FUN(filter, bias_mode, op, 2); \ + break; \ + \ + default: \ + megdnn_assert(0, "stride not supported"); \ + } + +#define GET_OP_PARAM(filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + GET_STRIDE_PARAM(filter, bias_mode, NoneOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + GET_STRIDE_PARAM(filter, bias_mode, ReluOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + GET_STRIDE_PARAM(filter, bias_mode, HSwishOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::SIGMOID: \ + GET_STRIDE_PARAM(filter, bias_mode, SigmoidOp<__fp16>) \ + break; \ + default: \ + megdnn_assert(0, "nonline not supported"); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + case BiasMode::BIAS: \ + GET_OP_PARAM(filter, BiasMode::BIAS) \ + break; \ + default: \ + megdnn_assert(0, "bias_mode not supported"); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 1: \ + GET_BIAS_MODE_PARAM(1) \ + break; \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0, "filter not supported"); \ + break; \ + } + + DISPATCH_CONV_KERN(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + WorkspaceBundle bundle = get_bundle(param); + + SmallVector ret_kerns; + + auto exec_one_group = [bundle, do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + auto fm = kern_param.filter_meta; + size_t IC = fm.icpg / 8; + size_t OC = fm.ocpg / 8; + bundle.set(kern_param.workspace_ptr); + for (size_t ic = 0; ic < IC; ic++) { + copy_padding_kern(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); + } + for (size_t oc = 0; oc < OC; oc++) { + do_conv_fun(bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); + } + }; + // TODO: large group only, further multithread optimization required + ret_kerns.push_back({exec_one_group, {group, batch, 1_z}}); + + return ret_kerns; +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp 
b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp new file mode 100644 index 00000000..354ce3b1 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp @@ -0,0 +1,307 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +using namespace megdnn; +using namespace arm_common; + +template +struct compute_fma { + static inline void call(const float16x8_t* ri, const float16x8_t* rf, + float16x8_t* rdst) { +#if defined(__aarch64__) + rdst[bw] = vfmaq_laneq_f16(rdst[bw], rf[pc], ri[bw], pc); +#else + rdst[bw] = vfmaq_f16(rdst[bw], rf[pc], + vdupq_n_f16(vgetq_lane_f16(ri[bw], pc))); +#endif + compute_fma::call(ri, rf, rdst); + } +}; + +template +struct compute_fma { + static inline void call(const float16x8_t* ri, const float16x8_t* rf, + float16x8_t* rdst) { + compute_fma::call(ri, rf, rdst); + } +}; + +template +struct compute_fma { + static inline void call(const float16x8_t* ri, const float16x8_t* rf, + float16x8_t* rdst) {} +}; + +template +struct load_dst { + static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) { + rdst[bw] = vld1q_f16(dst_ptr + bw * PC); + load_dst::call(rdst, dst_ptr); + } +}; + +template +struct load_dst { + static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) {} +}; + +template +struct load_src { + static inline void call(float16x8_t* ri, const float16_t* src_ptr) { + ri[bw] = vld1q_f16(src_ptr + bw * SW * PC); + load_src::call(ri, src_ptr); + } +}; + +template +struct load_src { + static inline void call(float16x8_t* ri, const float16_t* src_ptr) {} +}; + +template +struct load_filter { + static inline void call(float16x8_t* rf, const float16_t* filter_ptr) { + rf[pc] = vld1q_f16(filter_ptr + pc * PC); + load_filter::call(rf, filter_ptr); + } +}; + +template +struct load_filter { + static inline void call(float16x8_t* rf, const float16_t* filter_ptr) {} +}; + +template +struct store_dst { + static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) { + vst1q_f16(dst_ptr + bw * PC, rdst[bw]); + store_dst::call(rdst, dst_ptr); + } +}; + +template +struct store_dst { + static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) {} +}; + +template +static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst, + const float16_t* filter, int IW, int OW, + int& ow) { + constexpr int PC = 8; + constexpr int FW = FH; + constexpr int SW = SH; + + float16x8_t rf[PC]; + if (FH == 1 && FW == 1) { + load_filter::call(rf, filter); + } + + for (; ow + BW - 1 < OW; ow += BW) { + float16x8_t rdst[BW]; + load_dst::call(rdst, dst); + + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW; ++fw) { + float16x8_t ri[BW]; + load_src::call(ri, src + (fh * IW + fw) * PC); + + if (FH > 1 || FW > 1) { + load_filter::call(rf, + filter + (fh * FW + fw) * PC * PC); + } + + compute_fma::call(ri, rf, rdst); + } + } + + 
store_dst::call(rdst, dst); + + src += SW * BW * PC; + dst += BW * PC; + } +} + +template +static void do_load_bias_kern(float16_t* dst, const float16_t* bias, int OH, + int OW) { + constexpr int PC = 8; + + if (bias_mode == BiasMode::NO_BIAS) { + memset(dst, 0, OH * OW * PC * sizeof(float16_t)); + } else if (bias_mode == BiasMode::BIAS) { + memcpy(dst, bias, OH * OW * PC * sizeof(float16_t)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + float16x8_t bias_v = vld1q_f16(bias); + int i = 0; + for (; i + 3 < OH * OW; i += 4) { + vst1q_f16(dst + PC * 0, bias_v); + vst1q_f16(dst + PC * 1, bias_v); + vst1q_f16(dst + PC * 2, bias_v); + vst1q_f16(dst + PC * 3, bias_v); + dst += PC * 4; + } + for (; i < OH * OW; i += 1) { + vst1q_f16(dst, bias_v); + dst += PC; + } + } +} + +template +static void do_op_kern(float16_t* dst, int OH, int OW) { + constexpr int PC = 8; + + Op op; + + int i = 0; + for (; i + 3 < OH * OW; i += 4) { + float16x8_t dst0 = vld1q_f16(dst + PC * 0); + float16x8_t dst1 = vld1q_f16(dst + PC * 1); + float16x8_t dst2 = vld1q_f16(dst + PC * 2); + float16x8_t dst3 = vld1q_f16(dst + PC * 3); + + dst0 = op(dst0); + dst1 = op(dst1); + dst2 = op(dst2); + dst3 = op(dst3); + + vst1q_f16(dst + PC * 0, dst0); + vst1q_f16(dst + PC * 1, dst1); + vst1q_f16(dst + PC * 2, dst2); + vst1q_f16(dst + PC * 3, dst3); + + dst += PC * 4; + } + for (; i < OH * OW; i += 1) { + vst1q_f16(dst, op(vld1q_f16(dst))); + dst += PC; + } +} + +template +static void do_conv_kern(const float16_t* src, float16_t* dst, + const float16_t* filter, int IC, int IH, int IW, + int OH, int OW) { + constexpr int PC = 8; + constexpr int FW = FH; + + for (int ic = 0; ic < IC; ic += 1) { + const float16_t* src_ptr_h = src; + float16_t* dst_ptr_h = dst; + + for (int oh = 0; oh < OH; oh += 1) { + const float16_t* src_ptr_w = src_ptr_h; + float16_t* dst_ptr_w = dst_ptr_h; + + int ow = 0; + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, OW, + ow); + if (OW & 3) { + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, + OW, ow); + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, + OW, ow); + } + src_ptr_h += SH * IW * PC; + dst_ptr_h += OW * PC; + } + src += IH * IW * PC; + filter += FH * FW * PC * PC; + } +} + +static void do_conv_kern_1x1(const float16_t* src, float16_t* dst, + const float16_t* filter, int IC, int OH, int OW) { + constexpr int PC = 8; + const int IH = OH; + const int IW = OW; + const int IHW = IH * IW; + const int OHW = OH * OW; + + for (int ic = 0; ic < IC; ic += 1) { + const float16_t* src_ptr_hw = src; + float16_t* dst_ptr_hw = dst; + + int ohw = 0; + do_conv_kern_1xBW<1, 1, 8>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, + ohw); + do_conv_kern_1xBW<1, 1, 4>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, + ohw); + do_conv_kern_1xBW<1, 1, 1>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, + ohw); + src += IHW * PC; + filter += PC * PC; + } +} + +template +void conv_bias::conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, + const __fp16* bias, __fp16* dst, int IC, + int IH, int IW, int OH, int OW) { + do_load_bias_kern(dst, bias, OH, OW); + if (FH == 1 && SH == 1 && IH == OH && IW == OW) { + do_conv_kern_1x1(src, dst, filter, IC, OH, OW); + } else { + do_conv_kern(src, dst, filter, IC, IH, IW, OH, OW); + } + do_op_kern(dst, OH, OW); +} + +#define INSTANTIATION(stride, filter, bias, Op) \ + template void \ + conv_bias::conv_direct_fp16_nchw88( \ + const __fp16*, const __fp16*, const __fp16*, __fp16*, int, int, \ + int, int, int); + +#define FOR_OP(stride, filter, bias) \ + 
INSTANTIATION(stride, filter, bias, SigmoidOp<__fp16>) \ + INSTANTIATION(stride, filter, bias, ReluOp<__fp16>) \ + INSTANTIATION(stride, filter, bias, HSwishOp<__fp16>) \ + INSTANTIATION(stride, filter, bias, NoneOp<__fp16>) + +#define FOR_BIAS(stride, filter) \ + FOR_OP(stride, filter, BiasMode::NO_BIAS) \ + FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(stride, filter, BiasMode::BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 1) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +#define FOR_STRIDE \ + FOR_FILTER(1) \ + FOR_FILTER(2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_BIAS +#undef FOR_OP +#undef INSTANTIATION + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h new file mode 100644 index 00000000..dafedc3e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +template +void conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, + const __fp16* bias, __fp16* dst, int IC, int IH, + int IW, int OH, int OW); + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn + +#endif diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index e3bb59c9..19ca06b3 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -86,6 +86,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF16Direct f16_direct; AlgoF16DirectStride1 f16_direct_stride1; AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88; + AlgoF16DirectNCHW88 f16_direct_nchw88; #endif SmallVector> refhold; @@ -121,6 +122,7 @@ public: m_direct_algos.emplace_back(&f16_direct_stride1); m_direct_algos.emplace_back(&f16_direct); m_direct_algos.emplace_back(&f16_channel_wise_nchw88); + m_direct_algos.emplace_back(&f16_direct_nchw88); #endif m_direct_algos.emplace_back(&i8x8x16_direct); m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); @@ -252,7 +254,6 @@ public: } } - for (auto&& algo : m_direct_algos) { m_all_algos_map.emplace(algo->info().desc, algo); } @@ -261,8 +262,7 @@ public: } } - const SmallVector& direct_algos() - const { + const SmallVector& direct_algos() const { return m_direct_algos; } const SmallVector& winograd_algos() diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 225288e3..38106a77 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -10,9 +10,9 @@ * implied. 
*/ #pragma once +#include "src/common/algo_base.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/opr_impl.h" -#include "src/common/algo_base.h" namespace megdnn { namespace arm_common { @@ -28,7 +28,8 @@ public: } }; - SmallVector get_all_packed_algo() override; + SmallVector get_all_packed_algo() + override; bool is_matmul_quantized_prefer( const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) @@ -97,6 +98,7 @@ private: class AlgoF16Direct; class AlgoF16DirectStride1; class AlgoF16ChannelWiseNCHW88; + class AlgoF16DirectNCHW88; #endif class AlgoPack; diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index fe5393e8..2f875284 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -56,8 +56,7 @@ public: bool is_thread_safe() const override { return true; } void exec_preprocess(const TensorLayout& src_layout, - _megdnn_tensor_in filter, - _megdnn_tensor_in bias, + _megdnn_tensor_in filter, _megdnn_tensor_in bias, const TensorLayout& z_layout, const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter, @@ -243,6 +242,7 @@ public: ARM_COMMON_DIRECT_FP16, ARM_COMMON_DIRECT_STRD1_FP16, ARM_COMMON_CHWNWISE_NCHW88_F16, + ARM_COMMON_DIRECT_NCHW88_FP16, ARM_COMMON_WINOGRAD_F23_4X4_FP32, ARM_COMMON_WINOGRAD_F63_FP32, ARM_COMMON_WINOGRAD_F63_4X4_FP32, @@ -288,7 +288,7 @@ public: #else ARMV7_MATMUL_S8, ARMV7_MATMUL_QU8, -#endif // MEGDNN_AARCH64 +#endif // MEGDNN_AARCH64 #endif }; diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 1c697a18..57477fda 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -124,8 +124,8 @@ std::vector get_nchw44_channel_wise_args( for (size_t n : {1, 2}) { for (auto nlmode : nonlinemode) { for (bool pad : {true}) { - for (size_t group : {1, 2, 4, 7, 128}) { - for (size_t size : {4, 6, 7, 9, 15, 40}) { + for (size_t group : {1, 2, 4, 7, 16}) { + for (size_t size : {4, 6, 7, 9, 20}) { for (size_t kern : kernel) { pack(n, group, size, size, kern, stride, nlmode, pad); @@ -134,8 +134,8 @@ std::vector get_nchw44_channel_wise_args( } } for (bool pad : {false}) { - for (size_t group : {1, 2, 7, 128}) { - for (size_t size : {7, 9, 15, 40}) { + for (size_t group : {1, 2, 7, 16}) { + for (size_t size : {7, 9, 20}) { for (size_t kern : kernel) { pack(n, group, size, size, kern, stride, nlmode, pad); @@ -199,8 +199,8 @@ std::vector get_nchw88_channel_wise_args( for (size_t n : {1, 2}) { for (auto nlmode : nonlinemode) { for (bool pad : {true}) { - for (size_t group : {1, 2, 4, 7, 8, 128}) { - for (size_t size : {4, 6, 7, 9, 15, 40}) { + for (size_t group : {1, 2, 4, 7, 8, 16}) { + for (size_t size : {4, 6, 7, 9, 20}) { for (size_t kern : kernel) { pack(n, group, size, size, kern, stride, nlmode, pad); @@ -209,8 +209,8 @@ std::vector get_nchw88_channel_wise_args( } } for (bool pad : {false}) { - for (size_t group : {1, 2, 7, 128}) { - for (size_t size : {7, 9, 15, 40}) { + for (size_t group : {1, 2, 7, 16}) { + for (size_t size : {7, 9, 20}) { for (size_t kern : kernel) { pack(n, group, size, size, kern, stride, nlmode, pad); @@ -412,6 +412,23 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) { get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03); } + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) { + NormalRNG rng(1); + 
checker_conv_bias_f16( + get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, + ALL_BIASMODE, 1), + handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, + ALL_BIASMODE, 2), + handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03); +} + #endif /**********************************algo 8816 direct************************/ @@ -794,8 +811,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { check_winograd("1:6:32", checker, args); } - - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { using namespace conv_bias; std::vector args = get_winograd_mk_packed_args(); @@ -804,19 +819,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); } - - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { using namespace conv_bias; std::vector args = - get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1); + get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); Checker checker(handle()); check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4, param::ConvBias::Format::NCHW44); } - - //! uncomment it when low precision mode is ok #if 0 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { @@ -847,8 +858,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { check_winograd("1:5:32", checker, args); } - - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { using namespace conv_bias; std::vector args = get_winograd_args(5); @@ -971,18 +980,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { using namespace conv_bias; Checker checker(handle()); - auto run = [&checker](const std::vector& args, - DType A_dtype, + auto run = [&checker](const std::vector& args, DType A_dtype, DType B_dtype, DType C_dtype, DType D_dtype, float eps) { for (auto&& arg : args) { - checker.set_dtype(0, A_dtype) - .set_dtype(1, B_dtype) - .set_dtype(2, C_dtype) - .set_dtype(4, D_dtype) - .set_epsilon(eps) - .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); } }; @@ -997,9 +1005,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { std::vector quantized_args = get_int8_nchw44_args(3, 4); UniformIntRNG int_rng{-50, 50}; checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); - run(quantized_args, dtype::QuantizedS8(2.5f), - dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), - dtype::QuantizedS8(60.25f),1e-3); + run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3); } TEST_F(ARM_COMMON_MULTI_THREADS, diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index 18884324..fa061cb2 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -400,7 +400,8 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, data_type); } -TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, 
BENCHMARK_CONVBIAS_F16_NCHW88) { + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CHANNEL_WISE_FP16_NCHW88) { constexpr size_t RUNS = 50; std::string algo_name = "F16_CHANNEL_WISE_NCHW88"; @@ -462,6 +463,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) { bench_case(1, 64, 28, 28, 2, 0, 2); } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FP16_NCHW88) { + constexpr size_t RUNS = 40; + std::vector data_type = {dtype::Float16(), dtype::Float16(), + dtype::Float16(), dtype::Float16()}; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW88; + auto OH = (H + 2 * P - FS) / static_cast(S) + 1; + auto OW = (W + 2 * P - FS) / static_cast(S) + 1; + TensorShape src = {N, IC / 8, H, W, 8}; + TensorShape filter = {OC / 8, IC / 8, FS, FS, 8, 8}; + if (group > 1) { + filter = {group, OC / group / 8, IC / group / 8, FS, FS, 8, 8}; + param.sparse = param::ConvBias::Sparse::GROUP; + } + TensorShape bias = {1, OC / 8, 1, 1, 8}; + TensorShape dst = {N, OC / 8, OH, OW, 8}; + + SmallVector shapes{src, filter, bias, {}, dst}; + float computations = + (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + bench_case(1, 64, 64, 28, 28, 3, 1, 1, 1); + bench_case(1, 64, 64, 28, 28, 5, 1, 2, 1); + bench_case(1, 64, 64, 28, 28, 7, 1, 3, 1); + + bench_case(1, 64, 64, 28, 28, 3, 1, 1, 2); + bench_case(1, 64, 64, 28, 28, 5, 1, 2, 2); + bench_case(1, 64, 64, 28, 28, 7, 1, 3, 2); + + bench_case(1, 64, 64, 28, 28, 3, 2, 1, 1); + bench_case(1, 64, 64, 28, 28, 3, 4, 1, 1); + bench_case(1, 64, 64, 28, 28, 3, 8, 1, 1); + + bench_case(1, 16, 16, 28, 28, 3, 1, 1, 1); + bench_case(1, 32, 32, 28, 28, 3, 1, 1, 1); + bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); + bench_case(1, 256, 256, 28, 28, 3, 1, 1, 1); + + bench_case(1, 64, 64, 7, 7, 3, 1, 1, 1); + bench_case(1, 64, 64, 14, 14, 3, 1, 1, 1); + bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); + bench_case(1, 64, 64, 112, 112, 3, 1, 1, 1); +} + #endif TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) { @@ -769,10 +828,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1); bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1); bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1); - } -TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { constexpr size_t RUNS = 40; std::vector data_type = { dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), @@ -825,16 +884,13 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2 bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); - } - #endif TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { constexpr size_t RUNS = 40; - std::vector data_type = { - dtype::Float32(), dtype::Float32(), - dtype::Float32(), dtype::Float32()}; + 
std::vector data_type = {dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, size_t group, size_t P, size_t S, bool is_nchw = false) { @@ -880,15 +936,12 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); - - bench_case(1, 64, 64, 56*2, 56*2, 3, 4, 1, 2); - bench_case(1, 128, 128, 28*2, 28*2, 3, 4, 1, 2); - bench_case(1, 256, 256, 14*2, 14*2, 3, 4, 1, 2); - bench_case(1, 512, 512, 7*2, 7*2, 3, 4, 1, 2); -} - - + bench_case(1, 64, 64, 56 * 2, 56 * 2, 3, 4, 1, 2); + bench_case(1, 128, 128, 28 * 2, 28 * 2, 3, 4, 1, 2); + bench_case(1, 256, 256, 14 * 2, 14 * 2, 3, 4, 1, 2); + bench_case(1, 512, 512, 7 * 2, 7 * 2, 3, 4, 1, 2); +} TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { @@ -1473,9 +1526,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_WINOGRAD_INT8) { algo_name = "WINOGRAD:ARMV7_INT16X16X32_MK8_4X8:8:2:32"; #endif - - std::vector data_type = {dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), - dtype::QuantizedS32(6.25f) ,dtype::QuantizedS8(60.25f) }; + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; printf("Benchmark WINOGRAD_IN8_MK8 algo\n"); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, data_type); @@ -1839,7 +1892,6 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); } - TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) { constexpr size_t RUNS = 50; @@ -1852,18 +1904,17 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, param.stride_w = 1; param.sparse = param::ConvBias::Sparse::DENSE; param.format = param::ConvBias::Format::NCHW44; - std::vector, float>> shapes_and_computation; auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, - size_t FS, size_t group=1) { - SmallVector shapes{{N, IC, H, W,4}, - {OC, IC / group, FS, FS,4,4}, + size_t FS, size_t group = 1) { + SmallVector shapes{{N, IC, H, W, 4}, + {OC, IC / group, FS, FS, 4, 4}, {/*1, OC, 1, 1*/}, {}, - {N, OC, H, W,4}}; - TensorShape dst{N, OC, H, W,4}; + {N, OC, H, W, 4}}; + TensorShape dst{N, OC, H, W, 4}; float computations = ((4 * IC / group) * FS * FS * dst.total_nr_elems() * 2 + dst.total_nr_elems()) * @@ -1907,9 +1958,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, #endif std::string algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"; printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96 algo\n"); - std::vector data_type = { - dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), - dtype::QuantizedS32(6.25f), {}}; + std::vector data_type = {dtype::QuantizedS8(2.5f), + dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), + {}}; benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, data_type); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, @@ -1917,10 +1969,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, data_type); - - algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192"; - printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 algo\n"); + printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 " + "algo\n"); 
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, data_type); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, @@ -1929,14 +1980,14 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, {1, {4}}, data_type); algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384"; - printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 algo\n"); + printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 " + "algo\n"); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, data_type); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, data_type); benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, data_type); - } #endif diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index fe202235..aacb62b0 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -1185,9 +1185,10 @@ void check_conv_bias_preprocess(std::vector args, } } -void checker_conv_bias_common(std::vector args, Handle* handle, - RNG* rng, float epsilon, DType type0, DType type1, - DType type2, DType type3, const char* algo_name) { +void checker_conv_bias_common(std::vector args, + Handle* handle, RNG* rng, float epsilon, + DType type0, DType type1, DType type2, + DType type3, const char* algo_name) { using namespace conv_bias; Checker checker(handle); @@ -1377,6 +1378,88 @@ std::vector get_nchw44_conv_bias_args( } return args; } + +std::vector get_nchw88_conv_bias_args( + std::vector kernel_vec, + std::vector nlmode_vec, + std::vector biasmode_vec, size_t stride) { + using namespace conv_bias; + using NLMode = param::ConvBias::NonlineMode; + + std::vector args; + + auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, + size_t kernel, size_t stride, size_t group, NLMode nlmode, + megdnn::BiasMode bias_mode) { + constexpr int pack_c = 8; + const size_t pad = kernel / 2; + auto oc_per_group = oc / group; + auto ic_per_group = ic / group; + + megdnn_assert(oc_per_group % pack_c == 0 && ic_per_group % pack_c == 0, + "ocpg/icpg not divided by 8"); + + size_t kernel_h = kernel; + size_t kernel_w = kernel; + param::ConvBias param; + param.format = param::ConvBias::Format::NCHW88; + + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = pad; + param.pad_w = pad; + param.nonlineMode = nlmode; + + auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; + auto weight_tensor_shape = TensorShape{ + oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; + auto bias_tensor_shape = TensorShape{}; + if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c}; + } else if (bias_mode == megdnn::BiasMode::BIAS) { + bias_tensor_shape = {n, oc / pack_c, + (h + 2 * pad - kernel) / stride + 1, + (w + 2 * pad - kernel) / stride + 1, pack_c}; + } + if (group == 1) { + param.sparse = param::ConvBias::Sparse::DENSE; + } else { + param.sparse = param::ConvBias::Sparse::GROUP; + weight_tensor_shape = TensorShape{group, + oc_per_group / pack_c, + ic_per_group / pack_c, + kernel_h, + kernel_w, + pack_c, + pack_c}; + } + args.emplace_back(param, src_tensor_shape, weight_tensor_shape, + bias_tensor_shape); + }; + + for (auto bias : biasmode_vec) + for (auto nlmode : nlmode_vec) + for (size_t n : {1, 2}) + for (size_t kernel : kernel_vec) + for (size_t oc : {8, 16}) + for (size_t ic : {8, 16, 24}) + for (size_t h : {1, 3, 12}) + for (size_t w : {1, 8, 
13}) { + for (size_t group = 1; group < oc / 8; + ++group) { + if (ic % (group * 8) || + oc % (group * 8)) { + continue; + } + if (kernel < h || kernel < w) { + continue; + } + pack(n, oc, ic, h, w, kernel, stride, + group, nlmode, bias); + } + } + return args; +} } // namespace conv_bias } // namespace test } // namespace megdnn
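
Note on the new kernel's semantics (illustrative commentary, not part of the patch): in NCHW88 the innermost dimension packs 8 channels, so the kernel above sees src as [IC/8, IH, IW, 8], the dense filter as [OC/8, IC/8, FH, FW, 8 (ic lane), 8 (oc lane)], the channel-broadcast bias as [OC/8, 8], and dst as [OC/8, OH, OW, 8] — the same shapes the tests and benchmarks in this patch construct. The algorithm runs three passes over dst: do_load_bias_kern seeds it with the bias, do_conv_kern accumulates over all input-channel packs (reading input that copy_padding_kern has already zero-padded to IH2 x IW2), and do_op_kern applies the nonlinearity. The scalar sketch below restates that computation for one (batch, group) slice; it uses plain float instead of __fp16 and a hypothetical function name purely for readability, and it is not the implementation.

// Illustrative scalar reference for the NCHW88 direct convolution above.
// Assumes src has already been zero-padded to IH2 x IW2 (copy_padding_kern)
// and bias is the BROADCAST_CHANNEL_BIAS case; identity nonlinearity shown.
void conv_nchw88_reference(const float* src, const float* filter,
                           const float* bias, float* dst, int IC, int OC,
                           int IH2, int IW2, int OH, int OW, int FH, int FW,
                           int SH, int SW) {
    constexpr int PC = 8;                     // channels per pack
    const int ICB = IC / PC, OCB = OC / PC;   // number of packs
    for (int ocb = 0; ocb < OCB; ++ocb)
        for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow)
                for (int ocl = 0; ocl < PC; ++ocl) {
                    // pass 1 in the real kernel: dst is pre-filled with bias
                    float acc = bias[ocb * PC + ocl];
                    // pass 2: accumulate over every input-channel pack and tap
                    for (int icb = 0; icb < ICB; ++icb)
                        for (int fh = 0; fh < FH; ++fh)
                            for (int fw = 0; fw < FW; ++fw)
                                for (int icl = 0; icl < PC; ++icl) {
                                    const int ih = oh * SH + fh;  // padded coords
                                    const int iw = ow * SW + fw;
                                    const float s =
                                            src[((icb * IH2 + ih) * IW2 + iw) * PC + icl];
                                    const float f =
                                            filter[((((ocb * ICB + icb) * FH + fh) * FW + fw) *
                                                            PC + icl) * PC + ocl];
                                    acc += s * f;
                                }
                    // pass 3 in the real kernel applies the Op (relu/h-swish/
                    // sigmoid) over dst; identity kept here for clarity.
                    dst[((ocb * OH + oh) * OW + ow) * PC + ocl] = acc;
                }
}

Keeping dst itself as the accumulator is what lets the three passes share one memory format; inside the vectorized inner loop the kernel holds BW output pixels times 8 output lanes in NEON registers and, for each of the 8 input lanes of a pack, fuses one 8-wide filter row into them with vfmaq_laneq_f16 (or a dup+fma fallback on AArch32).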