diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp new file mode 100644 index 00000000..fcf8801e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp @@ -0,0 +1,173 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace arm_common { +namespace conv_bias { +template <> +void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin, + const int, const int pw, const int pad_right, + const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride) { + constexpr int ic_step = 4; + rep_step(ic_idx, ic, ic_step) { + const float* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + memset(sptr_base, 0, sizeof(float) * pw * ic_step); + sptr_base += pw * ic_step; + memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step); + sptr_base += iw * ic_step; + sptr += iw * ic_step; + memset(sptr_base, 0, sizeof(float) * pad_right * ic_step); + sptr_base += pad_right * ic_step; + } + memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); + sptr_base += iw2 * pad_bottom * ic_step; + } +} + 
+namespace { + +static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, + const int odd_start, + const int src_idx, + const int iw_idx) { + constexpr int ic_step = 4; + const int src_offset = src_idx * ic_step; + const int even_offset = iw_idx / 2 * ic_step; + const int odd_offset = (odd_start + iw_idx / 2) * ic_step; + float32x4_t temp[8]; + temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); + temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); + temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); + temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); + temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); + temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); + temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); + temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); + vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); + vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); + vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); + vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); + vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); + vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); + vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); + vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); +} + +static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, + const int odd_start, + const int src_idx, const int iw_idx) { + constexpr int ic_step = 4; + const int src_offset = src_idx * ic_step; + const int even_offset = (iw_idx + 1) / 2 * ic_step; + const int odd_offset = (odd_start + iw_idx / 2) * ic_step; + float32x4_t temp[8]; + temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); + temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); + temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); + temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); + temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); + temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); + 
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); + temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); + vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); + vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); + vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); + vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); + vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); + vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); + vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); + vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); +} +} // namespace + +template <> +void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin, + const int ph, const int pw, const int pad_right, + const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride) { + constexpr int ic_step = 4; + int odd_start = megdnn::div_ceil(iw2, 2); + float32x4_t zero_v = vdupq_n_f32(0.f); + MEGDNN_MARK_USED_VAR(ph); + bool even_start = pw % 2 == 0; + rep_step(ic_idx, ic, ic_step) { + const float* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + int iw_idx = 0; + rep(idx, pw) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + } else { + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, + zero_v); + } + ++iw_idx; + } + int src_idx = 0; + if (even_start) { + for (; src_idx + 7 < iw; src_idx += 8) { + odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, + iw_idx); + iw_idx += 8; + } + } else { + for (; src_idx + 7 < iw; src_idx += 8) { + odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, + iw_idx); + iw_idx += 8; + } + } + for (; src_idx < iw; ++src_idx) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); + } else { + vst1q_f32(sptr_base + 
(odd_start + iw_idx / 2) * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); + } + ++iw_idx; + } + rep(idx, pad_right) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + } else { + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, + zero_v); + } + ++iw_idx; + } + sptr_base += iw2 * ic_step; + sptr += iw * ic_step; + } + memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); + sptr_base += iw2 * pad_bottom * ic_step; + } +} + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp new file mode 100644 index 00000000..dba54fd5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1(2); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp new file mode 100644 index 00000000..58430539 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2(2); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp new file mode 100644 index 00000000..f3d4fdfe --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1(3); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp new file mode 100644 index 00000000..76d3bae8 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2(3); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp new file mode 100644 index 00000000..3f703967 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1(5); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp new file mode 100644 index 00000000..bfe92661 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2(5); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp new file mode 100644 index 00000000..f77cdeab --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1(7); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp new file mode 100644 index 00000000..975e4389 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2(7); diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index ba8af1ef..b40e1688 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
@@ -12,7 +12,7 @@ */ #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -24,21 +24,21 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8]); \ - c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 8]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8], lane); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 8], lane); UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 1); @@ -47,15 +47,15 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4]); \ - c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 4]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4], lane); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 4], lane); UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 1); @@ -64,13 +64,13 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static 
MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8], lane); UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 1); @@ -79,13 +79,13 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4], lane); UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 1); @@ -95,11 +95,11 @@ struct ShiftCalHelper { } }; -template +template MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper::impl( + c, src, weight); }; template struct OCHelper { @@ -162,13 +162,11 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -209,18 +207,15 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - 
weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -260,32 +255,27 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<3, 0, c_dim, ow_block>(c, src, weight); src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<4, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -326,44 +316,37 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = 
vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<3, 0, c_dim, ow_block>(c, src, weight); src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<4, 0, c_dim, ow_block>(c, src, weight); src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<5, 0, c_dim, ow_block>(c, src, weight); src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<6, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -375,36 +358,14 @@ struct KerNeonXXs1Nchw44FP32 { } // namespace -void conv_bias::pack_src_fp32_nchw44_stride1( - float* sptr_base, const float* sptr_origin, const int, const int pw, - const int pad_right, const int ih, const int iw, const int iw2, - const int pad_top, const int pad_bottom, const int ic, - const int ic_stride) { - constexpr int ic_step = 4; - rep_step(ic_idx, ic, ic_step) { - const float* sptr = sptr_origin + ic_idx * ic_stride; - memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); - sptr_base += iw2 * pad_top * ic_step; - rep(ih_idx, ih) { - memset(sptr_base, 0, sizeof(float) * pw * ic_step); - sptr_base += pw * ic_step; - memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step); - sptr_base += iw * ic_step; - sptr += iw * ic_step; - memset(sptr_base, 0, sizeof(float) * 
pad_right * ic_step); - sptr_base += pad_right * ic_step; - } - memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); - sptr_base += iw2 * pad_bottom * ic_step; - } -} - -template -static void conv_direct_stride1_fp32_nchw44( - const float32_t* src, const float32_t* filter, const float32_t* bias, - float32_t*, float32_t* dst, const int oc, const int ic, const int ih, - const int iw, const int oh, const int oh_block, const int ow, - const Op& op, const int, const int) { +template +void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, + const float* bias, float*, float* dst, + const int oc, const int ic, + const int ih, const int iw, + const int oh, const int oh_block, + const int ow, const Op& op, const int, + const int) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -518,55 +479,23 @@ static void conv_direct_stride1_fp32_nchw44( } } -#define CONSTRUCT_FUNC(filter_size) \ - template \ - void conv_bias:: \ - conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \ - const float32_t* src, const float32_t* filter, \ - const float32_t* bias, float32_t* temp, float32_t* dst, \ - const int oc, const int ic, const int ih, const int iw, \ - const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw) { \ - conv_direct_stride1_fp32_nchw44( \ - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ - ow, op, ph, pw); \ - } -CONSTRUCT_FUNC(2); -CONSTRUCT_FUNC(3); -CONSTRUCT_FUNC(5); -CONSTRUCT_FUNC(7); -#undef CONSTRUCT_FUNC - -#define INSTANTIATION(stride, i, bias, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ - bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ - float32_t*, float32_t*, const int, const int, const int, \ - const int, const int, const int, const int, const Op&, \ - const int, const int); - -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(stride, i, bias, NoneOp) \ - 
INSTANTIATION(stride, i, bias, ReluOp) \ - INSTANTIATION(stride, i, bias, HSwishOp) \ - INSTANTIATION(stride, i, bias, SigmoidOp) - -#define FOR_BIAS(stride, i) \ - FOR_OP(stride, i, BiasMode::NO_BIAS) \ - FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ - FOR_OP(stride, i, BiasMode::BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(stride1) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION +#define INSTANTIATION(filter_size, bias_mode, Op) \ + template void \ + conv_bias::conv_direct_fp32_nchw44( \ + const float* src, const float* filter, const float* bias, float*, \ + float* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int, const int); + +#define FOR_OP(filter_size, bias) \ + INSTANTIATION(filter_size, bias, NoneOp) \ + INSTANTIATION(filter_size, bias, ReluOp) \ + INSTANTIATION(filter_size, bias, HSwishOp) \ + INSTANTIATION(filter_size, bias, SigmoidOp) + +#define INSTANTIATION_CONV_S1(filter_size) \ + FOR_OP(filter_size, BiasMode::NO_BIAS) \ + FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(filter_size, BiasMode::BIAS) + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h similarity index 63% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index 78019dd6..73141077 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -1,6 +1,6 @@ /** * \file - * 
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -12,7 +12,7 @@ */ #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -24,21 +24,21 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8]); \ - c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 8]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8], lane); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 8], lane); UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 1); @@ -47,15 +47,15 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4]); \ - c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 4]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4], lane); 
\ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 4], lane); UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 1); @@ -64,13 +64,13 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8], lane); UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 1); @@ -79,13 +79,13 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4]); +#define cb(step, lane) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4], lane); UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 1); @@ -95,11 +95,11 @@ struct ShiftCalHelper { } }; -template +template MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper::impl( + c, src, weight); }; template struct OCHelper { @@ -163,13 +163,13 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += 
ld_weight_fh; @@ -177,13 +177,13 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -224,18 +224,18 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -243,17 +243,17 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, 
weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -261,18 +261,18 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -316,30 +316,25 @@ struct KerNeonXXs2Nchw44FP32 { 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, 
ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); // odd element load_helper( src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; @@ -390,40 +385,33 @@ struct KerNeonXXs2Nchw44FP32 { 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<3, 0, c_dim, ow_block>(c, src, weight); // odd element load_helper( src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<0, 0, c_dim, 
ow_block>(c, src, weight); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<1, 0, c_dim, ow_block>(c, src, weight); src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, - weight); + cal_helper<2, 0, c_dim, ow_block>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; @@ -436,133 +424,15 @@ struct KerNeonXXs2Nchw44FP32 { }; } // namespace -namespace { - -inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, - const int odd_start, const int src_idx, - const int iw_idx) { - constexpr int ic_step = 4; - const int src_offset = src_idx * ic_step; - const int even_offset = iw_idx / 2 * ic_step; - const int odd_offset = (odd_start + iw_idx / 2) * ic_step; - float32x4_t temp[8]; - temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); - temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); - temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); - temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); - temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); - temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); - temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); - temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); - vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); - vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); - vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); - vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); - vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); - vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); - vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); - vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); -} - 
-inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, - const int odd_start, const int src_idx, - const int iw_idx) { - constexpr int ic_step = 4; - const int src_offset = src_idx * ic_step; - const int even_offset = (iw_idx + 1) / 2 * ic_step; - const int odd_offset = (odd_start + iw_idx / 2) * ic_step; - float32x4_t temp[8]; - temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); - temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); - temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); - temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); - temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); - temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); - temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); - temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); - vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); - vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); - vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); - vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); - vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); - vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); - vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); - vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); -} -} // namespace - -void conv_bias::pack_src_fp32_nchw44_stride2( - float* sptr_base, const float* sptr_origin, const int ph, const int pw, - const int pad_right, const int ih, const int iw, const int iw2, - const int pad_top, const int pad_bottom, const int ic, - const int ic_stride) { - constexpr int ic_step = 4; - int odd_start = megdnn::div_ceil(iw2, 2); - float32x4_t zero_v = vdupq_n_f32(0.f); - MEGDNN_MARK_USED_VAR(ph); - bool even_start = pw % 2 == 0; - rep_step(ic_idx, ic, ic_step) { - const float* sptr = sptr_origin + ic_idx * ic_stride; - memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); - sptr_base += iw2 * pad_top * ic_step; - rep(ih_idx, ih) { - int iw_idx = 0; - rep(idx, pw) { - if 
(iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); - } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - zero_v); - } - ++iw_idx; - } - int src_idx = 0; - if (even_start) { - for (; src_idx + 7 < iw; src_idx += 8) { - odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, - iw_idx); - iw_idx += 8; - } - } else { - for (; src_idx + 7 < iw; src_idx += 8) { - odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, - iw_idx); - iw_idx += 8; - } - } - for (; src_idx < iw; ++src_idx) { - if (iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); - } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); - } - ++iw_idx; - } - rep(idx, pad_right) { - if (iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); - } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - zero_v); - } - ++iw_idx; - } - sptr_base += iw2 * ic_step; - sptr += iw * ic_step; - } - memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); - sptr_base += iw2 * pad_bottom * ic_step; - } -} -template -static void conv_direct_stride2_fp32_nchw44( - const float32_t* src, const float32_t* filter, const float32_t* bias, - float32_t*, float32_t* dst, const int oc, const int ic, const int ih, - const int iw, const int oh, const int oh_block, const int ow, - const Op& op, const int, const int) { +template +void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, + const float* bias, float*, float* dst, + const int oc, const int ic, + const int ih, const int iw, + const int oh, const int oh_block, + const int ow, const Op& op, const int, + const int) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -697,55 +567,23 @@ static void conv_direct_stride2_fp32_nchw44( } } -#define CONSTRUCT_FUNC(filter_size) \ - template \ - void conv_bias:: \ - 
conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw44( \ - const float32_t* src, const float32_t* filter, \ - const float32_t* bias, float32_t* temp, float32_t* dst, \ - const int oc, const int ic, const int ih, const int iw, \ - const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw) { \ - conv_direct_stride2_fp32_nchw44( \ - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ - ow, op, ph, pw); \ - } -CONSTRUCT_FUNC(2); -CONSTRUCT_FUNC(3); -CONSTRUCT_FUNC(5); -CONSTRUCT_FUNC(7); -#undef CONSTRUCT_FUNC - -#define INSTANTIATION(stride, i, bias, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ - bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ - float32_t*, float32_t*, const int, const int, const int, \ - const int, const int, const int, const int, const Op&, \ - const int, const int); - -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(stride, i, bias, NoneOp) \ - INSTANTIATION(stride, i, bias, ReluOp) \ - INSTANTIATION(stride, i, bias, HSwishOp) \ - INSTANTIATION(stride, i, bias, SigmoidOp) - -#define FOR_BIAS(stride, i) \ - FOR_OP(stride, i, BiasMode::NO_BIAS) \ - FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ - FOR_OP(stride, i, BiasMode::BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(stride2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION +#define INSTANTIATION(filter_size, bias_mode, Op) \ + template void \ + conv_bias::conv_direct_fp32_nchw44( \ + const float* src, const float* filter, const float* bias, float*, \ + float* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int, const int); + +#define FOR_OP(filter_size, bias) \ + INSTANTIATION(filter_size, bias, NoneOp) \ + 
INSTANTIATION(filter_size, bias, ReluOp) \ + INSTANTIATION(filter_size, bias, HSwishOp) \ + INSTANTIATION(filter_size, bias, SigmoidOp) + +#define INSTANTIATION_CONV_S2(filter_size) \ + FOR_OP(filter_size, BiasMode::NO_BIAS) \ + FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(filter_size, BiasMode::BIAS) + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp new file mode 100644 index 00000000..0e4b75dd --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(2, 1); \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp new file mode 100644 index 00000000..231fb564 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(2, 2); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp new file mode 100644 index 00000000..7d9a6291 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(3, 1); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp new file mode 100644 index 00000000..e7d2e841 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(3, 2); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp new file mode 100644 index 00000000..21e5a14c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(5, 1); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp new file mode 100644 index 00000000..82a45e18 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(5, 2); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp new file mode 100644 index 00000000..09e1bc48 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(7, 1); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp new file mode 100644 index 00000000..ad8205f6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp @@ -0,0 +1,14 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV(7, 2); diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h new file mode 100644 index 00000000..70b9f128 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -0,0 +1,443 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/arch.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; + +namespace { +/** + *\brief ShiftCalHelper is core calculate code + *\tparam src_idx is offset for src regs + *\tparam weight_idx is offset for weight regs + *\tparam T is type of output regs + *\tparam T2 is type of src regs + *\tparam T3 is type of weight regs + */ +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ + src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][weight_idx], \ + src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ + src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, + weight); +}; +template +struct OCHelper { +public: + static const int val = -1; +}; + +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; + +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; +/** + * 
oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel + **/ +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op); +}; +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 7; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride>(c, src, weight); \ + cal_helper<4, 4, c_dim, stride>(c, src, weight); \ + cal_helper<5, 5, c_dim, stride>(c, src, weight); \ + cal_helper<6, 6, c_dim, stride>(c, src, weight); + + UNROLL_CALL_RAW(7, KERNEL_CB) +#undef KERNEL_CB + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44FP32 { + static void 
impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 5; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride>(c, src, weight); \ + cal_helper<4, 4, c_dim, stride>(c, src, weight); + UNROLL_CALL_RAW(5, KERNEL_CB) +#undef KERNEL_CB + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 3; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int 
ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + // row 0 + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, stride>(c, src, weight); + cal_helper<1, 1, c_dim, stride>(c, src, weight); + cal_helper<2, 2, c_dim, stride>(c, src, weight); + + // row 1 + load_helper( + src, src_ptr + iw, 0); + load_helper( + weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, stride>(c, src, weight); + cal_helper<1, 1, c_dim, stride>(c, src, weight); + cal_helper<2, 2, c_dim, stride>(c, src, weight); + + // row 2 + load_helper( + src, src_ptr + 2 * iw, 0); + load_helper( + weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, stride>(c, src, weight); + cal_helper<1, 1, c_dim, stride>(c, src, weight); + cal_helper<2, 2, c_dim, stride>(c, src, weight); + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int 
ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + // row 0 + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, stride>(c, src, weight); + cal_helper<1, 1, c_dim, stride>(c, src, weight); + + // row 1 + load_helper( + src, src_ptr + iw, 0); + load_helper( + weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, stride>(c, src, weight); + cal_helper<1, 1, c_dim, stride>(c, src, weight); + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +} // namespace + +template +void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op, const int, const int) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 1; + constexpr int big_oc_step = 8; + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int pack_iw_len = 1; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + 
KerNeonXXs2NchwNchw44FP32::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44FP32::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + 
oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} + +#define INSTANTIATION(stride, filter_size, bias_mode, Op) \ + template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \ + bias_mode, Op, filter_size, stride>( \ + const float32_t* src, const float32_t* filter, \ + const float32_t* bias, float32_t*, float32_t* dst, const int oc, \ + const int ic, const int ih, const int iw, const int oh, \ + const int oh_block, const int ow, const Op& op, const int, \ + const int); + +#define FOR_OP(stride, filter, bias) \ + INSTANTIATION(stride, filter, bias, NoneOp) \ + INSTANTIATION(stride, filter, bias, ReluOp) \ + INSTANTIATION(stride, filter, bias, HSwishOp) + +#define INSTANCE_CONV(filter, stride) \ + FOR_OP(stride, filter, BiasMode::NO_BIAS) \ + FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(stride, filter, BiasMode::BIAS) + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index 920aa183..49a62400 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -13,8 +13,8 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" + #include "src/arm_common/elemwise_op.h" #include "midout.h" @@ -112,17 +112,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle, const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + ncb_index.thread_id * src_size); - if (stride == 1) { - conv_bias::pack_src_fp32_nchw44_stride1( - sptr, origin_sptr, ph, pw, remain_right_pad, - ih_real - 
src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, - src_bottom_pad, ic, ih * iw); - } else { - conv_bias::pack_src_fp32_nchw44_stride2( - sptr, origin_sptr, ph, pw, remain_right_pad, - ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, - src_bottom_pad, ic, ih * iw); - } + + conv_bias::pack_src_fp32_nchw44( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); const float* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; @@ -135,25 +129,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle, kern_param.bias(batch_id, group_id) + bias_offset; Op op; - if (stride == 1) { -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \ - \ - bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ - ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) - - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); -#undef KERN1_NCHW44_CONV - } else { -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \ - \ - bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ - ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) - - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); -#undef KERN1_NCHW44_CONV - } + conv_bias::conv_direct_fp32_nchw44( + sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, + oh_block_real, ow, op, ph, pw); } } // namespace diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h new file mode 100644 index 00000000..2d5c7d9e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace arm_common { +namespace conv_bias { + +template +void conv_direct_fp32_nchw44(const float* src, const float* filter, + const float* bias, float*, float* dst, + const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, + const int ow, const Op& op, const int, const int); +template +void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int, + const int pw, const int pad_right, const int ih, + const int iw, const int iw2, const int pad_top, + const int pad_bottom, const int ic, + const int ic_stride); + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp index 82dd3231..41d7dad9 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp @@ -120,7 +120,8 @@ static void pack_weight(const WorkspaceBundle& bundle, kern_param.filter(group_id) + oc_idx * fh * fw * ic; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; - pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic); + fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, + oc_block, fh, fw, ic); } template @@ -180,7 +181,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, kern_param.bias(batch_id, group_id) + oc_idx; Op op; - conv_direct_fp32_nchw_nchw44( + fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( sptr, packed_weight, bptr, nullptr, 
dst, oc_block, ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, pw); } diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h index 70b7a1de..1617eedd 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h @@ -20,295 +20,12 @@ #include "src/fallback/conv_bias/common.h" namespace megdnn { namespace arm_common { -namespace { -/** - *\brief ShiftCalHelper is core calculate code - *\tparam src_idx is offset for src regs - *\tparam weight_idx is offset for weight regs - *\tparam T is type of output regs - *\tparam T2 is type of src regs - *\tparam T3 is type of weight regs - */ -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); -}; - -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ - c[0][step], weight[0][weight_idx], \ - src[(step * stride + src_idx) / 4]); \ - c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \ - c[1][step], weight[1][weight_idx], \ - src[(step * stride + src_idx) / 4]); - - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ - c[0][step], weight[0][weight_idx], \ - src[(step * stride + src_idx) / 4]); - - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; - -template -MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( - c, src, weight); -}; -template -struct OCHelper { -public: - static const int val = -1; -}; - -template <> -struct OCHelper<4> { -public: - static const int val = 1; -}; - -template <> -struct OCHelper<8> { -public: - static const int val = 
2; -}; -/** - * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel - **/ -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op); -}; -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int loop_ic_step = 1; - constexpr int filter_size = 7; - constexpr int oc_step = 4; - constexpr int simd_len = 4; - constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; - - constexpr int ld_weight_fw = oc_step * filter_size; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - const int ld_weight_ic = oc_step * filter_size * filter_size; - const int ld_src_ic = ih * iw; - constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; - -#define KERNEL_CB(step) \ - load_helper( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - - UNROLL_CALL_RAW(7, KERNEL_CB) -#undef KERNEL_CB - - src_ptr += ld_src_ic; - weight_ptr += ld_weight_ic; - } - 
store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); - } -}; -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int loop_ic_step = 1; - constexpr int filter_size = 5; - constexpr int oc_step = 4; - constexpr int simd_len = 4; - constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; - - constexpr int ld_weight_fw = oc_step * filter_size; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - const int ld_weight_ic = oc_step * filter_size * filter_size; - const int ld_src_ic = ih * iw; - constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; - -#define KERNEL_CB(step) \ - load_helper( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ - cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - UNROLL_CALL_RAW(5, KERNEL_CB) -#undef KERNEL_CB - - src_ptr += ld_src_ic; - weight_ptr += ld_weight_ic; - } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int loop_ic_step = 1; - constexpr int filter_size = 3; - constexpr int oc_step = 4; - 
constexpr int simd_len = 4; - constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; - - constexpr int ld_weight_fw = oc_step * filter_size; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - const int ld_weight_ic = oc_step * filter_size * filter_size; - const int ld_src_ic = ih * iw; - constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; - // row 0 - load_helper(src, src_ptr, - 0); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); +namespace fp32_direct_nchw_nchw44 { - // row 1 - load_helper( - src, src_ptr + iw, 0); - load_helper( - weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - - // row 2 - load_helper( - src, src_ptr + 2 * iw, 0); - load_helper( - weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - - src_ptr += ld_src_ic; - weight_ptr += ld_weight_ic; - } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int loop_ic_step = 1; - constexpr int 
filter_size = 2; - constexpr int oc_step = 4; - constexpr int simd_len = 4; - constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; - - constexpr int ld_weight_fw = oc_step * filter_size; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - const int ld_weight_ic = oc_step * filter_size * filter_size; - const int ld_src_ic = ih * iw; - constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; - // row 0 - load_helper(src, src_ptr, - 0); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - - // row 1 - load_helper( - src, src_ptr + iw, 0); - load_helper( - weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); - - src_ptr += ld_src_ic; - weight_ptr += ld_weight_ic; - } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); - } -}; -void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr, - const int oc, const int kh, const int kw, - const int ic) { +static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, + float32_t* dst_ptr, + const int oc, const int kh, + const int kw, const int ic) { constexpr int oc_step = 4; const int filter_oc_stride = kh * kw * ic; const int filter_ic_stride = kh * kw * oc_step; @@ -327,115 +44,15 @@ void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr, } } } - template -static void conv_direct_fp32_nchw_nchw44( - const float32_t* src, const float32_t* filter, const float32_t* bias, - float32_t*, float32_t* dst, const int oc, const int ic, const int ih, - const 
int iw, const int oh, const int oh_block, const int ow, - const Op& op, const int, const int) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 1; - constexpr int big_oc_step = 8; - constexpr int oc_step = 4; - constexpr int ih_step = 1; - constexpr int oh_step = 1; - constexpr int ow_step = 8; - constexpr int stride_h = stride; - constexpr int stride_w = stride; - constexpr int pack_iw_len = 1; +void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter, + const float32_t* bias, float32_t*, + float32_t* dst, const int oc, const int ic, + const int ih, const int iw, const int oh, + const int oh_block, const int ow, + const Op& op, const int, const int); +} // namespace fp32_direct_nchw_nchw44 - const int img_stride = oh * ow; - const int ow_end = ow / ow_step * ow_step; - const int ow_remain = ow - ow_end; - const int oc_end = oc / big_oc_step * big_oc_step; - const int oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ - break; - - UNROLL_CALL_RAW(8, cb); - default: - megdnn_assert(0, "no remain %d for kern", ow_remain); - } - for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const int weight_offset = oc_idx * ic * fh * fw; - for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { - for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, 
iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const int src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); - } - } - } - if (oc_remain > 0) { - int oc_idx = oc_end; - const int weight_offset = oc_idx * ic * fh * fw; - for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { - for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const int src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - } - } -} -} // namespace } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h deleted file mode 100644 index c58d3d91..00000000 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
- * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/common.h" -namespace megdnn { -namespace arm_common { -namespace conv_bias { -#define KERN(stride, i, layout) \ - template \ - void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ - const float* src, const float* filter, const float* bias, \ - float* temp, float* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw); - -KERN(stride1, 2, nchw44) -KERN(stride1, 3, nchw44) -KERN(stride1, 5, nchw44) -KERN(stride1, 7, nchw44) -#undef KERN - -void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin, - const int ph, const int pw, - const int pad_right, const int ih, - const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride); -} // namespace conv_bias -} // namespace arm_common -} // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h deleted file mode 100644 index a0d852a2..00000000 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. 
- */ - -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/common.h" -namespace megdnn { -namespace arm_common { -namespace conv_bias { -#define KERN(stride, i, layout) \ - template \ - void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ - const float* src, const float* filter, const float* bias, \ - float* temp, float* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw); - -KERN(stride2, 2, nchw44) -KERN(stride2, 3, nchw44) -KERN(stride2, 5, nchw44) -KERN(stride2, 7, nchw44) -#undef KERN - -void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin, - const int ph, const int pw, - const int pad_right, const int ih, - const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride); -} // namespace conv_bias -} // namespace arm_common -} // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp index 1ebb1eef..fbe0e6c6 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #ifdef __ARM_FEATURE_DOTPROD @@ -17,7 +18,7 @@ #include "src/fallback/conv_bias/common.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" -#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h" + namespace megdnn { namespace arm_common { namespace direct_dotprod_nchw44 { @@ -139,234 +140,9 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, } } -template -void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, - const int8_t* src, const int ih, const int iw, - const int8_t* filter, const int32_t* bias, - const int oh_size, const int oc, const int ic, - const Op& op) { - constexpr int FH = filter_size; - constexpr int FW = filter_size; - constexpr int IC_PACK_SIZE = 4; - constexpr int OC_PACK_SIZE = 4; - -#if MEGDNN_AARCH64 - constexpr int OC_BIG_INTERVAL = 12; - constexpr int OC_MID_INTERVAL = 8; - constexpr int OC_SMA_INTERVAL = 4; -#else - constexpr int OC_BIG_INTERVAL = 4; - constexpr int OC_MID_INTERVAL = 4; - constexpr int OC_SMA_INTERVAL = 4; -#endif - - constexpr int OW_INTERVAL = 8; - constexpr int SH = stride; - - const int dst_numbers_per_channel = oh * ow; - const int ow_remain = ow % OW_INTERVAL; - const int ow_end_idx = ow - ow_remain; - const int oc_remain = - oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 - const int oc_end_idx = oc - oc_remain; - const int dst_numbers_4channel_packed = - dst_numbers_per_channel * OC_PACK_SIZE; - - using remain_fun = std::function; - - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_mid_oc_remain = nullptr; - remain_fun kern_sma_oc_remain = nullptr; - - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_mid_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_sma_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - break; - UNROLL_CALL_RAW(8, cb); -#undef cb - default: - megdnn_assert(0, "no remain %d for kern", ow_remain); - } - - //! 
filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] - //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, - //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates - //! [OW_INTERVAL, 1, OC_INTERVAL] each time - for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { - const int filter_offset_in_element = oc_idx * ic * FH * FW; - for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { - for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { - const int src_offset_in_element = - (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; - const int bias_offset_in_element = oc_idx; - KernNeonSdotNCHW44:: - impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); - } - if (ow_remain) { - const int src_offset_in_element = - (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; - const int bias_offset_in_element = oc_idx; - kern_big_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); - } - } - } - -#ifdef MEGDNN_AARCH64 - //! 
oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 - if (oc_remain) { - int oc_idx = oc_end_idx; - const int filter_offset_in_element = oc_idx * ic * FH * FW; - for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { - for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { - const int src_offset_in_element = - (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; - const int bias_offset_in_element = oc_idx; - if (oc_remain == 8) { - KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_MID_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + bias_offset_in_element, - ic, op); - } else { - KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_SMA_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + bias_offset_in_element, - ic, op); - } - } - if (ow_remain) { - const int src_offset_in_element = - (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; - const int bias_offset_in_element = oc_idx; - if (oc_remain == 8) { - kern_mid_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); - } else { - kern_sma_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); - } - } - } - } -#endif -} - -#define CONSTRUCT_FUNC(filter_size) \ - template \ - void 
conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ - dst_type* dst, const int oh, const int ow, const int8_t* src, \ - const int ih, const int iw, const int8_t* weight, \ - const int32_t* bias, const int oh_size, const int oc, \ - const int ic, const Op& op) { \ - conv_direct_sdot_int8_nchw44( \ - dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \ - } - -CONSTRUCT_FUNC(2); -CONSTRUCT_FUNC(3); -CONSTRUCT_FUNC(5); -CONSTRUCT_FUNC(7); -#undef CONSTRUCT_FUNC - -#define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \ - template void conv_direct_##i##x##i##_int8_nchw44( \ - dst_type * dst, const int oh, const int ow, const int8_t* src, \ - const int ih, const int iw, const int8_t* weight, \ - const int32_t* bias, const int oh_size, const int oc, \ - const int ic, const Op& op); - -#define FOR_OP(stride, i, bias_mode) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - TypeCvtOp) \ - INSTANTIATION(dt_int32, stride, i, bias_mode, \ - NoneOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - ReluOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - HSwishOp) - -#define FOR_BIAS(stride, i) \ - FOR_OP(stride, i, BiasMode::NO_BIAS) \ - FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(1) -FOR_FILTER(2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION - } // namespace direct_dotprod_nchw44 } // namespace arm_common } // namespace megdnn #endif -//vim: syntax=cpp.doxygen +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h index 809befd0..f5ffac31 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h @@ -6,7 +6,8 @@ * * Unless required by 
applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #if __ARM_FEATURE_DOTPROD @@ -42,20 +43,13 @@ using BiasMode = ConvBiasForward::BiasMode; * @return none */ -#define KERN(filter_size) \ - template \ - void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ - dst_type* dst, const int oh, const int ow, const int8_t* src, \ - const int ih, const int iw, const int8_t* weight, \ - const int32_t* bias, const int oh_size, const int oc, \ - const int ic, const Op& op) - -KERN(2); -KERN(3); -KERN(5); -KERN(7); - -#undef KERN +template +void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, + const int8_t* src, const int ih, const int iw, + const int8_t* filter, const int32_t* bias, + const int oh_size, const int oc, const int ic, + const Op& op); /** * @brief : copy data from src to dst for direct conv with no side effect * @param : [output ptr] dst @@ -84,4 +78,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, #endif -//vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp index f7b08374..33d91091 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -148,14 +148,10 @@ static void conv_kern(const WorkspaceBundle& bundle, float scale_dst = ncb_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } - -#define KERN1_NCHW44_CONV(filter) \ - direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \ - dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \ - ih_real_size, 
iw2, weights, bias, \ - oh_real_size, OC, IC, op); - DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV); -#undef KERN1_NCHW44_CONV + direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< + dst_type, stride, bias_mode, Op, filter_size>( + dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias, + oh_real_size, OC, IC, op); } } // namespace @@ -342,4 +338,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( #endif -//vim: syntax=cpp.doxygen +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h deleted file mode 100644 index 3d9e9a08..00000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h +++ /dev/null @@ -1,435 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#pragma once -#ifdef __ARM_FEATURE_DOTPROD - -#include "megdnn/arch.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" -#include "src/arm_common/intrinsic_helper.h" -#include "src/arm_common/neon_struct.h" -#include "src/common/unroll_macro.h" - -namespace megdnn { -namespace arm_common { -namespace direct_dotprod_nchw44 { - -constexpr int SIMD_LEN = 16; -constexpr int IC_PACK_SIZE = 4; -constexpr int OC_PACK_SIZE = 4; -constexpr int filter_next_col = - IC_PACK_SIZE * OC_PACK_SIZE; //! 
[OC/4, IC/4, FH, FW, 4OC, 4IC] - -template -MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], - const int32_t* bias_ptr, int oc_step) { - static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); - if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { -#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); - switch (row) { - case 3: - UNROLL_CALL_RAW(8, BIAS_INIT, 2); - case 2: - UNROLL_CALL_RAW(8, BIAS_INIT, 1); - default: - UNROLL_CALL_RAW(8, BIAS_INIT, 0); - } -#undef BIAS_INIT - } else { -#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); - switch (row) { - case 3: - UNROLL_CALL_RAW(8, BIAS_INIT, 2); - case 2: - UNROLL_CALL_RAW(8, BIAS_INIT, 1); - default: - UNROLL_CALL_RAW(8, BIAS_INIT, 0); - } -#undef BIAS_INIT - } -} - -#define cb11(col) \ - op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); - -#define cb21(col) \ - op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ - op(res[1][col], \ - reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); - -#define cb31(col) \ - op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ - op(res[1][col], \ - reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); \ - op(res[2][col], reinterpret_cast(dst_ptr + ld_dst_oc + \ - ld_dst_oc + col / 2 * 8)); - -#define cb12(step) \ - op({{res[0][2 * step], res[0][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + step * 8)); - -#define cb22(step) \ - op({{res[0][2 * step], res[0][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + step * 8)); \ - op({{res[1][2 * step], res[1][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); - -#define cb32(step) \ - op({{res[0][2 * step], res[0][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + step * 8)); \ - op({{res[1][2 * step], res[1][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); \ - op({{res[2][2 * step], res[2][2 * step + 1]}}, \ - reinterpret_cast(dst_ptr + 2 * ld_dst_oc + step * 8)); - -template -struct StoreOCxOWx { - static 
MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc); -}; - -template -struct StoreOCxOWx<1, ow_remain, Op, T> { - - static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, - const int ld_dst_oc) { - MEGDNN_MARK_USED_VAR(ld_dst_oc); - switch (ow_remain) { - case 8: - UNROLL_CALL_RAW(4, cb12); - break; - case 7: - cb11(6); - case 6: - UNROLL_CALL_RAW(3, cb12); - break; - case 5: - cb11(4); - case 4: - UNROLL_CALL_RAW(2, cb12); - break; - case 3: - cb11(2); - case 2: - UNROLL_CALL_RAW(1, cb12); - break; - case 1: - cb11(0); - default: - break; - } - } -}; - -template -struct StoreOCxOWx<2, ow_remain, Op, T> { - static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc) { - switch (ow_remain) { - case 8: - UNROLL_CALL_RAW(4, cb22); - break; - case 7: - cb21(6); - case 6: - UNROLL_CALL_RAW(3, cb22); - break; - case 5: - cb21(4); - case 4: - UNROLL_CALL_RAW(2, cb22); - break; - case 3: - cb21(2); - case 2: - UNROLL_CALL_RAW(1, cb22); - break; - case 1: - cb21(0); - default: - break; - } - } -}; - -template -struct StoreOCxOWx<3, ow_remain, Op, T> { - static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc) { - switch (ow_remain) { - case 8: - UNROLL_CALL_RAW(4, cb32); - break; - case 7: - cb31(6); - case 6: - UNROLL_CALL_RAW(3, cb32); - break; - case 5: - cb31(4); - case 4: - UNROLL_CALL_RAW(2, cb32); - break; - case 3: - cb31(2); - case 2: - UNROLL_CALL_RAW(1, cb32); - break; - case 1: - cb31(0); - default: - break; - } - } -}; - -#undef cb11 -#undef cb21 -#undef cb31 -#undef cb12 -#undef cb22 -#undef cb32 - -template -MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], - const Op& op, T* dst_ptr, - const int ld_dst_oc) { - StoreOCxOWx::impl(res, op, dst_ptr, ld_dst_oc); -} - -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { -#define cb(step) 
\ - res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \ - res[res_row][step], weight[weight_idx], \ - src[src_row][(src_start_idx + step) / 4]); - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; - -template -MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { - ShiftCalHelper::impl(res, src, weight); -}; - -/** - * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) - * gemm like kernel - * */ -template -struct KernNeonSdotNCHW44 { - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op); -}; - -template -struct KernNeonSdotNCHW44 { - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op) { - constexpr int FH = filter_size; - constexpr int FW = filter_size; - constexpr int filter_next_row = - FW * OC_PACK_SIZE * - IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] - - const int filter_next_4oc = - FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] - const int src_next_ic = ih * iw; - const int src_next_row = iw * IC_PACK_SIZE; - - constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1; - constexpr int LOOP = oc_interval / 4; - - int32x4_t res[3][ow_interval]; - init_ocx_ow8(res, bias, OC_PACK_SIZE); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { - const int8_t* i_src = src + ic_idx * src_next_ic; - const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; - for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { - int8x16_t src[1][4]; - int8x16_t weight[3]; - - load_helper(src, i_src, 0); - -//! do not use switch order 3,2,1 because it will slow the speed. 
-#define CALC_PART(step) \ - switch (LOOP) { \ - case 1: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ - break; \ - case 2: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ - break; \ - case 3: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ - weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ - filter_next_col * step); \ - cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \ - break; \ - default: \ - break; \ - } - - switch (filter_size) { - case 2: - UNROLL_CALL_RAW(2, CALC_PART); - break; - case 3: - UNROLL_CALL_RAW(3, CALC_PART); - break; - case 5: - UNROLL_CALL_RAW(5, CALC_PART); - break; - case 7: - UNROLL_CALL_RAW(7, CALC_PART); - break; - default: - break; - } -#undef CALC_PART - - i_filter += filter_next_row; - i_src += src_next_row; - } - } - store_ocx_owx_remain_static(res, op, dst, - dst_step); - } -}; - -template -struct KernNeonSdotNCHW44 { - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op) { - constexpr int FH = filter_size; - constexpr int FW = filter_size; - constexpr int filter_next_row = - FW * OC_PACK_SIZE * - IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] - - const int filter_next_4oc = - FH * FW * ic * OC_PACK_SIZE; //! 
[OC/4, IC/4, FH, FW, 4OC, 4IC] - const int src_next_ic = ih * iw; - const int src_next_row = iw * IC_PACK_SIZE; - - constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1; - constexpr int LOOP = oc_interval / 4; - - int32x4_t res[3][ow_interval]; - init_ocx_ow8(res, bias, OC_PACK_SIZE); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { - const int8_t* i_src = src + ic_idx * src_next_ic; - const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; - for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { - int8x16_t src[2][3]; - int8x16_t weight[3]; - const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE; - - load_helper(src, i_src, offset); - -//! do not use switch order 3,2,1 because it will slow the speed. -#define CALC_PART(step) \ - switch (LOOP) { \ - case 1: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ - weight); \ - break; \ - case 2: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ - weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ - weight); \ - break; \ - case 3: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ - weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ - weight); \ - weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ - filter_next_col * step); \ - cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \ - weight); \ - break; \ - default: \ - break; \ - } - - switch (filter_size) { - case 2: - UNROLL_CALL_RAW(2, CALC_PART); - break; - case 3: - 
UNROLL_CALL_RAW(3, CALC_PART); - break; - case 5: - UNROLL_CALL_RAW(5, CALC_PART); - break; - case 7: - UNROLL_CALL_RAW(7, CALC_PART); - break; - default: - break; - } -#undef CALC_PART - - i_filter += filter_next_row; - i_src += src_next_row; - } - } - store_ocx_owx_remain_static(res, op, dst, - dst_step); - } -}; - -} // namespace direct_dotprod_nchw44 -} // namespace arm_common -} // namespace megdnn - -#endif - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h new file mode 100644 index 00000000..6e4e1029 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h @@ -0,0 +1,245 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#if __ARM_FEATURE_DOTPROD +#include "megdnn/arch.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/intrinsic_helper.h" +#include "src/arm_common/neon_struct.h" +#include "src/common/unroll_macro.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { + +constexpr int SIMD_LEN = 16; +constexpr int IC_PACK_SIZE = 4; +constexpr int OC_PACK_SIZE = 4; +constexpr int filter_next_col = + IC_PACK_SIZE * OC_PACK_SIZE; //! 
[OC/4, IC/4, FH, FW, 4OC, 4IC] + +template +MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], + const int32_t* bias_ptr, int oc_step) { + static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); + switch (row) { + case 3: + UNROLL_CALL_RAW(8, BIAS_INIT, 2); + case 2: + UNROLL_CALL_RAW(8, BIAS_INIT, 1); + default: + UNROLL_CALL_RAW(8, BIAS_INIT, 0); + } +#undef BIAS_INIT + } else { +#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); + switch (row) { + case 3: + UNROLL_CALL_RAW(8, BIAS_INIT, 2); + case 2: + UNROLL_CALL_RAW(8, BIAS_INIT, 1); + default: + UNROLL_CALL_RAW(8, BIAS_INIT, 0); + } +#undef BIAS_INIT + } +} + +#define cb11(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); + +#define cb21(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ + op(res[1][col], \ + reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); + +#define cb31(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ + op(res[1][col], \ + reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); \ + op(res[2][col], reinterpret_cast(dst_ptr + ld_dst_oc + \ + ld_dst_oc + col / 2 * 8)); + +#define cb12(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); + +#define cb22(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); \ + op({{res[1][2 * step], res[1][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); + +#define cb32(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); \ + op({{res[1][2 * step], res[1][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); \ + op({{res[2][2 * step], res[2][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + 2 * ld_dst_oc + step * 8)); + +template +struct StoreOCxOWx { + static 
MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, + T* dst_ptr, const int ld_dst_oc); +}; + +template +struct StoreOCxOWx<1, ow_remain, Op, T> { + static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, + const int ld_dst_oc) { + MEGDNN_MARK_USED_VAR(ld_dst_oc); + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb12); + break; + case 7: + cb11(6); + case 6: + UNROLL_CALL_RAW(3, cb12); + break; + case 5: + cb11(4); + case 4: + UNROLL_CALL_RAW(2, cb12); + break; + case 3: + cb11(2); + case 2: + UNROLL_CALL_RAW(1, cb12); + break; + case 1: + cb11(0); + default: + break; + } + } +}; + +template +struct StoreOCxOWx<2, ow_remain, Op, T> { + static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, + T* dst_ptr, const int ld_dst_oc) { + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb22); + break; + case 7: + cb21(6); + case 6: + UNROLL_CALL_RAW(3, cb22); + break; + case 5: + cb21(4); + case 4: + UNROLL_CALL_RAW(2, cb22); + break; + case 3: + cb21(2); + case 2: + UNROLL_CALL_RAW(1, cb22); + break; + case 1: + cb21(0); + default: + break; + } + } +}; + +template +struct StoreOCxOWx<3, ow_remain, Op, T> { + static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, + T* dst_ptr, const int ld_dst_oc) { + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb32); + break; + case 7: + cb31(6); + case 6: + UNROLL_CALL_RAW(3, cb32); + break; + case 5: + cb31(4); + case 4: + UNROLL_CALL_RAW(2, cb32); + break; + case 3: + cb31(2); + case 2: + UNROLL_CALL_RAW(1, cb32); + break; + case 1: + cb31(0); + default: + break; + } + } +}; + +#undef cb11 +#undef cb21 +#undef cb31 +#undef cb12 +#undef cb22 +#undef cb32 + +template +MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], + const Op& op, T* dst_ptr, + const int ld_dst_oc) { + StoreOCxOWx::impl(res, op, dst_ptr, ld_dst_oc); +} + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { +#define cb(step) \ 
+ res[res_row][step] = \ + vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \ + src[src_row][(src_start_idx + step) / 4], \ + (src_start_idx + step) % 4); + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { + ShiftCalHelper::impl(res, src, weight); +}; + +/** + * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) + * gemm like kernel + * */ +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op); +}; + +} // namespace direct_dotprod_nchw44 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp new file mode 100644 index 00000000..737ef33b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp @@ -0,0 +1,320 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int filter_next_row = + FW * OC_PACK_SIZE * + IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + + const int filter_next_4oc = + FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + const int src_next_ic = ih * iw; + const int src_next_row = iw * IC_PACK_SIZE; + + constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1; + constexpr int LOOP = oc_interval / 4; + + int32x4_t res[3][ow_interval]; + init_ocx_ow8(res, bias, OC_PACK_SIZE); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { + const int8_t* i_src = src + ic_idx * src_next_ic; + const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; + for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { + int8x16_t src[1][4]; + int8x16_t weight[3]; + + load_helper(src, i_src, 0); + +//! do not use switch order 3,2,1 because it will slow the speed. 
+#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + break; \ + case 2: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, 0, step, 1>(res, src, weight); \ + break; \ + case 3: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, 0, step, 1>(res, src, weight); \ + weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ + filter_next_col * step); \ + cal_helper<2, 0, step, 2>(res, src, weight); \ + break; \ + default: \ + break; \ + } + + switch (filter_size) { + case 2: + UNROLL_CALL_RAW(2, CALC_PART); + break; + case 3: + UNROLL_CALL_RAW(3, CALC_PART); + break; + case 5: + UNROLL_CALL_RAW(5, CALC_PART); + break; + case 7: + UNROLL_CALL_RAW(7, CALC_PART); + break; + default: + break; + } +#undef CALC_PART + + i_filter += filter_next_row; + i_src += src_next_row; + } + } + store_ocx_owx_remain_static(res, op, dst, + dst_step); + } +}; + +template +void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, + const int8_t* src, const int ih, const int iw, + const int8_t* filter, const int32_t* bias, + const int oh_size, const int oc, const int ic, + const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int IC_PACK_SIZE = 4; + constexpr int OC_PACK_SIZE = 4; + +#if MEGDNN_AARCH64 + constexpr int OC_BIG_INTERVAL = 12; + constexpr int OC_MID_INTERVAL = 8; + constexpr int OC_SMA_INTERVAL = 4; +#else + constexpr int OC_BIG_INTERVAL = 4; + constexpr int OC_MID_INTERVAL = 4; + constexpr int 
OC_SMA_INTERVAL = 4; +#endif + + constexpr int OW_INTERVAL = 8; + constexpr int SH = stride; + + const int dst_numbers_per_channel = oh * ow; + const int ow_remain = ow % OW_INTERVAL; + const int ow_end_idx = ow - ow_remain; + const int oc_remain = + oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 + const int oc_end_idx = oc - oc_remain; + const int dst_numbers_4channel_packed = + dst_numbers_per_channel * OC_PACK_SIZE; + + using remain_fun = std::function; + + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_mid_oc_remain = nullptr; + remain_fun kern_sma_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_mid_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_sma_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + break; + UNROLL_CALL_RAW(8, cb); +#undef cb + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] + //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, + //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates + //! 
[OW_INTERVAL, 1, OC_INTERVAL] each time + for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + KernNeonSdotNCHW44:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + kern_big_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + +#ifdef MEGDNN_AARCH64 + //! 
oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 + if (oc_remain) { + int oc_idx = oc_end_idx; + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_MID_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } else { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_SMA_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + kern_mid_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } else { + kern_sma_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + } +#endif +} + +#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ + 
template void conv_direct_sdot_int8_nchw44( \ + dst_type * dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, \ + const int32_t* bias, const int oh_size, const int oc, \ + const int ic, const Op& op); + +#define FOR_OP(stride, i, bias_mode) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + TypeCvtOp) \ + INSTANTIATION(dt_int32, stride, i, bias_mode, \ + NoneOp) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + ReluOp) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + HSwishOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(1) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION + +} // namespace direct_dotprod_nchw44 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp new file mode 100644 index 00000000..4f5b8ebf --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp @@ -0,0 +1,322 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int filter_next_row = + FW * OC_PACK_SIZE * + IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + + const int filter_next_4oc = + FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + const int src_next_ic = ih * iw; + const int src_next_row = iw * IC_PACK_SIZE; + + constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1; + constexpr int LOOP = oc_interval / 4; + + int32x4_t res[3][ow_interval]; + init_ocx_ow8(res, bias, OC_PACK_SIZE); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { + const int8_t* i_src = src + ic_idx * src_next_ic; + const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; + for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { + int8x16_t src[2][3]; + int8x16_t weight[3]; + const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE; + + load_helper(src, i_src, offset); + +//! do not use switch order 3,2,1 because it will slow the speed. 
+#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + break; \ + case 2: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ + break; \ + case 3: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ + weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ + filter_next_col * step); \ + cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \ + break; \ + default: \ + break; \ + } + + switch (filter_size) { + case 2: + UNROLL_CALL_RAW(2, CALC_PART); + break; + case 3: + UNROLL_CALL_RAW(3, CALC_PART); + break; + case 5: + UNROLL_CALL_RAW(5, CALC_PART); + break; + case 7: + UNROLL_CALL_RAW(7, CALC_PART); + break; + default: + break; + } +#undef CALC_PART + + i_filter += filter_next_row; + i_src += src_next_row; + } + } + store_ocx_owx_remain_static(res, op, dst, + dst_step); + } +}; + +template +void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, + const int8_t* src, const int ih, const int iw, + const int8_t* filter, const int32_t* bias, + const int oh_size, const int oc, const int ic, + const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int IC_PACK_SIZE = 4; + constexpr int OC_PACK_SIZE = 4; + +#if MEGDNN_AARCH64 + constexpr int OC_BIG_INTERVAL = 12; + constexpr int OC_MID_INTERVAL = 8; + constexpr int OC_SMA_INTERVAL = 4; +#else + constexpr int OC_BIG_INTERVAL = 4; 
+ constexpr int OC_MID_INTERVAL = 4; + constexpr int OC_SMA_INTERVAL = 4; +#endif + + constexpr int OW_INTERVAL = 8; + constexpr int SH = stride; + + const int dst_numbers_per_channel = oh * ow; + const int ow_remain = ow % OW_INTERVAL; + const int ow_end_idx = ow - ow_remain; + const int oc_remain = + oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 + const int oc_end_idx = oc - oc_remain; + const int dst_numbers_4channel_packed = + dst_numbers_per_channel * OC_PACK_SIZE; + + using remain_fun = std::function; + + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_mid_oc_remain = nullptr; + remain_fun kern_sma_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_mid_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_sma_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + break; + UNROLL_CALL_RAW(8, cb); +#undef cb + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] + //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, + //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates + //! 
[OW_INTERVAL, 1, OC_INTERVAL] each time + for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + KernNeonSdotNCHW44:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + kern_big_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + +#ifdef MEGDNN_AARCH64 + //! 
oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 + if (oc_remain) { + int oc_idx = oc_end_idx; + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_MID_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } else { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_SMA_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + kern_mid_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } else { + kern_sma_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + } +#endif +} + +#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ + 
template void conv_direct_sdot_int8_nchw44( \
+            dst_type * dst, const int oh, const int ow, const int8_t* src, \
+            const int ih, const int iw, const int8_t* weight, \
+            const int32_t* bias, const int oh_size, const int oc, \
+            const int ic, const Op& op);
+
+#define FOR_OP(stride, i, bias_mode) \
+    INSTANTIATION(dt_int8, stride, i, bias_mode, \
+                  TypeCvtOp) \
+    INSTANTIATION(dt_int32, stride, i, bias_mode, \
+                  NoneOp) \
+    INSTANTIATION(dt_int8, stride, i, bias_mode, \
+                  ReluOp) \
+    INSTANTIATION(dt_int8, stride, i, bias_mode, \
+                  HSwishOp)
+
+#define FOR_BIAS(stride, i) \
+    FOR_OP(stride, i, BiasMode::NO_BIAS) \
+    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
+
+#define FOR_FILTER(stride) \
+    FOR_BIAS(stride, 2) \
+    FOR_BIAS(stride, 3) \
+    FOR_BIAS(stride, 5) \
+    FOR_BIAS(stride, 7)
+
+FOR_FILTER(2)
+
+#undef FOR_STRIDE
+#undef FOR_FILTER
+#undef FOR_IC
+#undef FOR_BIAS
+#undef FOR_NONLINEAR
+#undef FOR_REMAIN
+#undef INSTANTIATION
+
+}  // namespace direct_dotprod_nchw44
+}  // namespace arm_common
+}  // namespace megdnn
+
+#endif
+// vim: syntax=cpp.doxygen
\ No newline at end of file
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
new file mode 100644
index 00000000..1c879d91
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
@@ -0,0 +1,448 @@
+/**
+ * \file
+ * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#if __ARM_FEATURE_DOTPROD +#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" +namespace megdnn { +namespace arm_common { +namespace dot_direct_nchw_nchw44 { + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \ + c[1][step] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; +////////////////////stride 1/////////////////// +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 2; + constexpr int filter_width = 4; + constexpr int weight_reg = 2; + constexpr int src_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw * pack_iw_len, 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 
+ load_helper( + src, src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 3; + constexpr int filter_width = 4; + constexpr int weight_reg = 3; + constexpr int src_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw * pack_iw_len, 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 2 + load_helper( + src, src_ptr + 2 * iw * pack_iw_len, 0); + cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + 
int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 5; + constexpr int filter_width = 8; + constexpr int src_reg = 3; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; + +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + + UNROLL_CALL_RAW(5, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 7; + constexpr int filter_width = 8; + constexpr int src_reg = 3; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; 
ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + + UNROLL_CALL_RAW(7, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template <> +void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, + const int8_t* sptr_origin, const int, + const int pw, const int, const int ih, + const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride, + int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, + 2, 3, 4, 5, 3, 4, 5, 6}; + uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); + + constexpr int iw_step = 16; + constexpr int pack_iw_len = 4; + const int iw_with_pad = iw + 2 * pw; + const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; + rep(ic_idx, ic) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * + pack_iw_len); + sptr_base += iw2 * pad_top * pack_iw_len; + rep(ih_idx, ih) { + memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); + memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); + for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { + int8x16_t src[4]; + int8x16_t dst[4]; + src[0] = vld1q_s8(temp_ptr + iw_idx); + src[1] = vld1q_s8(temp_ptr + iw_idx + 4); + src[2] = vld1q_s8(temp_ptr + iw_idx + 8); + src[3] = vld1q_s8(temp_ptr + iw_idx + 12); + dst[0] = vqtbl1q_s8(src[0], tbl_idx); + dst[1] = vqtbl1q_s8(src[1], tbl_idx); + dst[2] = vqtbl1q_s8(src[2], tbl_idx); + dst[3] = vqtbl1q_s8(src[3], tbl_idx); + 
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); + } + for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { + *(sptr_base + iw_idx * pack_iw_len + 0) = + *(temp_ptr + iw_idx + 0); + *(sptr_base + iw_idx * pack_iw_len + 1) = + *(temp_ptr + iw_idx + 1); + *(sptr_base + iw_idx * pack_iw_len + 2) = + *(temp_ptr + iw_idx + 2); + *(sptr_base + iw_idx * pack_iw_len + 3) = + *(temp_ptr + iw_idx + 3); + } + sptr_base += iw2 * pack_iw_len; + sptr += iw; + } + sptr_base += iw2 * pad_bottom * pack_iw_len; + } +} + +template +void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, + const int ih, const int iw, const int oh, + const int oh_block, const int ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr int fh = filter_size; + constexpr int fw = (filter_size + 3) / 4 * 4; +#if MEGDNN_AARCH64 + constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 4; +#endif + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int pack_iw_len = stride == 2 ? 
1 : 4; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + kern_small_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; 
+ KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template void \ + conv_direct_int8_nchw_nchw44_dot( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const int oc, const int ic, \ + const int ih, const int iw, const int oh, const int oh_block, \ + const int ow, const Op& op); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(1); + +} // namespace dot_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp new file mode 100644 index 00000000..46e0177f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp @@ -0,0 +1,437 @@ +/** + * \file + * 
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#if __ARM_FEATURE_DOTPROD
+#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
+namespace megdnn {
+namespace arm_common {
+namespace dot_direct_nchw_nchw44 {
+
+template
+struct ShiftCalHelper {
+    static void impl(T& c, T2& src, T3& weight) {
+#define cb(step) \
+    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
+            c[0][step * 2], weight[0][weight_idx], \
+            src[0][(src_idx + step) / 4]); \
+    c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
+            c[1][step * 2], weight[1][weight_idx], \
+            src[0][(src_idx + step) / 4]); \
+    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
+            c[0][step * 2 + 1], weight[0][weight_idx], \
+            src[1][(src_idx + step) / 4]); \
+    c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
+            c[1][step * 2 + 1], weight[1][weight_idx], \
+            src[1][(src_idx + step) / 4]);
+
+        UNROLL_CALL_RAW(4, cb);
+#undef cb
+    }
+};
+
+template
+struct ShiftCalHelper {
+    static void impl(T& c, T2& src, T3& weight) {
+#define cb(step) \
+    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
+            c[0][step * 2], weight[0][weight_idx], \
+            src[0][(src_idx + step) / 4]); \
+    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
+            c[0][step * 2 + 1], weight[0][weight_idx], \
+            src[1][(src_idx + step) / 4]);
+
+        UNROLL_CALL_RAW(4, cb);
+#undef cb
+    }
+};
+
+template
+struct KerNeonDotXXs2Nchw44Int8 {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
+                     int iw, int 
ld_dst_oc, const Op& op) { + constexpr int stride = 2; + constexpr int filter_hight = 2; + constexpr int filter_width = 4; + constexpr int weight_reg = 1; + constexpr int src_reg = 1; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 2; + constexpr int filter_hight = 3; + constexpr int filter_width = 4; + constexpr int weight_reg = 1; + constexpr int src_reg = 1; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx 
= 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 2 + load_helper( + src, src_ptr + 2 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 2; + constexpr int filter_hight = 5; + constexpr int filter_width = 8; + constexpr int src_reg = 2; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper(src, src_ptr + step * iw, \ + stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<1, 1, c_dim, Vdotq_laneq_s32, 
ow_block, stride>(c, src, weight); + UNROLL_CALL_RAW(5, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += 5 * 32; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +/** + * oc = 8, ow = 8 + * dot 4 element, pad last filter and do twice dot every row filter, filter like + * below + * -------------------------- + * |x, x, x, x,| x, x, x, 0 | + * -------------------------- + **/ +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 2; + constexpr int filter_hight = 7; + constexpr int filter_width = 8; + constexpr int src_reg = 2; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper(src, src_ptr + step * iw, \ + stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + UNROLL_CALL_RAW(7, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += 7 * 32; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template <> +void pack_src_int8_nchw_nchw44_dot<2>( + int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw, + const int, const int ih, const int iw, const int iw2, const int pad_top, + const int pad_bottom, const 
int ic, const int ic_stride, int8_t*) { + constexpr int ic_step = 1; + rep_step(ic_idx, ic, ic_step) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + memcpy(sptr_base + pw * ic_step, sptr, + sizeof(int8_t) * iw * ic_step); + sptr_base += iw2 * ic_step; + sptr += iw * ic_step; + } + sptr_base += iw2 * pad_bottom * ic_step; + } +} + +template +void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, + const int ih, const int iw, const int oh, + const int oh_block, const int ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr int fh = filter_size; + constexpr int fw = (filter_size + 3) / 4 * 4; +#if MEGDNN_AARCH64 + constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 4; +#endif + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int pack_iw_len = stride == 2 ? 
1 : 4; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + kern_small_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; 
+ KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} + +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template void \ + conv_direct_int8_nchw_nchw44_dot( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const int oc, const int ic, \ + const int ih, const int iw, const int oh, const int oh_block, \ + const int ow, const Op& op); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(2); + +} // namespace dot_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp new file mode 100644 index 00000000..c7736149 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp @@ -0,0 +1,743 @@ +/** + * \file + * 
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/arm_common/conv_bias/int8/direct.h"
+#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
+#include "src/arm_common/conv_bias/intrinsic_helper.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/common/utils.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace {
+template
+static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
+                                             const int8_t* weight_ptr,
+                                             const int32_t* bias_ptr,
+                                             DstType* dst_ptr, int ic, int ih,
+                                             int iw, int ld_dst_oc,
+                                             const Op& op) {
+    constexpr int fh = filter_size;
+    constexpr int fw = filter_size;
+    constexpr int ic_step = 4;
+    constexpr int oc_step = 4;
+    constexpr int loop_ic_step = 4;
+    constexpr int ld_weight_ic4 = 16;
+    constexpr int pack_iw_len = 4;
+
+    const int ic_stride = ih * iw * pack_iw_len;
+    const int ld_weight_oc4 = oc_step * fh * fw * ic;
+
+    int32x4_t c[2][8];
+    int8x16_t weight[2][2];
+    int8x16_t src[8 + 1];
+    int16x8_t temp_c[4];
+
+    init_ocx_ow8(c, bias_ptr, oc_step);
+
+    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
+            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
+                                       fh_idx * iw * ic_step * pack_iw_len;
+
+            src[0] = vld1q_s8(src_ic_0_3);
+            src[1] = vld1q_s8((src_ic_0_3 + 16));
+            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
+            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
+            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
+            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
+            src[6] = 
vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]); + + c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]); + c[1][5] = 
vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]); + + c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[1][8]; + int8x16_t weight[1][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + 
fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]); + + c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]); + + c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc); +}; +/** +dot like impl. dot 4 ic to 1 oc, accumale to c +example: (format like weight) +packed weight +low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> +--------------------------------------------------------------------- +high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> +dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> +**/ +//! 
TODO: can try oh = 2 impl, oc = 8 impl +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 3; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[3]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], 
src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); + + c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); + + c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + 
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); + 
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); + + c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + 
constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[4], 
src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); + c[0][4] = 
vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]); + + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + + c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, + const size_t ic, const size_t ih, + const size_t iw, const size_t oh, + const size_t ow, const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + 
const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_oc = oh * ow * oc_step; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_2x2s1_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s1_oc4_ow8; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_oc, op); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + 
ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_oc, op); + } + } + } +} +template +void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, + const size_t ic, const size_t ih, + const size_t iw, const size_t oh, + const size_t ow, const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const int ld_dst_oc = oh * ow * oc_step; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + + using remain_fun = std::function; + + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_small_oc_remain = \ + KerNeonDirectStride1Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride1Int8::impl(src + src_offset, + filter + weight_offset, + bias + 
oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, op, ld_dst_oc); + } + } + } +} +} // namespace + +namespace int8_direct_nchw44 { +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + conv_direct_stride1_int8_nchw44_kern( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + conv_direct_stride1_2x2_int8_nchw44( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } +}; + +#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ + template struct ConvDirectInt8Nchw44Choose; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + TypeCvtOp) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + ReluOp) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + HSwishOp) \ + DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 
5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(1); + +} // namespace int8_direct_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp new file mode 100644 index 00000000..c202512b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp @@ -0,0 +1,778 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc); +}; + +template +static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + 
constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int32x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[4]; + + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], 
temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); + c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr 
int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[0], 
c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); + c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} +/** +dot like impl. dot 4 ic to 1 oc, accumale to c +example: (format like weight) +packed weight +low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> +--------------------------------------------------------------------- +high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> +dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> +**/ +// TODO: can try oh = 2 impl, oc = 8 impl +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 3; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[3]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[4]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] 
= vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); + + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 15 * 
16)); + src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); + c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[4]; + init_ocx_ow8(c, bias_ptr, oc_step); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + 
weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]); + + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[2], 
src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); + + src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); + c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[4]; + init_ocx_ow8(c, bias_ptr, oc_step); + for (int ic_idx = 0; ic_idx < ic; ic_idx 
+= loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); + 
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]); + + src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], 
temp_c[1]); + + src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); + src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); + c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +void conv_direct_stride2_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t out_img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc 
/ big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oh * ow * oc_step; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_2x2s2_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s2_oc4_ow8; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias 
+ oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} + +template +void conv_direct_stride2_int8_nchw44_kern( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const int ld_dst_oc = oh * ow * oc_step; + + using remain_fun = std::function; + + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_small_oc_remain = \ + KerNeonDirectStride2Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride2Int8::impl(src + src_offset, + filter + weight_offset, + bias + 
oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, op, ld_dst_oc); + } + } + } +} +} // namespace + +namespace int8_direct_nchw44 { +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + conv_direct_stride2_int8_nchw44_kern( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + conv_direct_stride2_2x2_int8_nchw44( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } +}; + +#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ + template struct ConvDirectInt8Nchw44Choose; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + TypeCvtOp) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + ReluOp) \ + DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ + \ + HSwishOp) \ + DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + 
GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(2); + +} // namespace int8_direct_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h new file mode 100644 index 00000000..5e81bd7e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h @@ -0,0 +1,47 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" +namespace megdnn { +namespace arm_common { +namespace { + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op); +}; + +template +struct OCHelper { +public: + static const int val = 0; +}; +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; + +} // namespace +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp new file mode 100644 index 00000000..2722a94b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp @@ -0,0 +1,561 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" +#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" +namespace megdnn { +namespace arm_common { +namespace { +/** + * @brief core code for calculation patten + * + * @tparam src_idx is offset of src reg + * @tparam weight_idx is offset of weight reg + * @tparam c_dim is output channel + * @tparam Func mla operation funcion + * @tparam stride + * @tparam T outpur regs type + * @tparam T2 src regs type + * @tparam T3 weight regs type + * @tparam T4 temp regs type + */ + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); +}; +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { + ShiftCalHelper::impl( + c, src, weight, temp); +} +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl( + c, src, weight); +}; +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx], + c[0][0], temp[0]); + c[1][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[1][weight_idx], + c[1][0], temp[1]); + c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx], + c[0][1], temp[2]); + c[1][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[1][weight_idx], + c[1][1], temp[3]); + c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx], + c[0][2], temp[0]); + c[1][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[1][weight_idx], + c[1][2], temp[1]); + c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx], + c[0][3], temp[2]); + c[1][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[1][weight_idx], + c[1][3], temp[3]); + + c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx], + c[0][4], temp[0]); + c[1][4] = 
vdotq_s32_h(src[(4 + src_idx) % 8], weight[1][weight_idx], + c[1][4], temp[1]); + c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx], + c[0][5], temp[2]); + c[1][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[1][weight_idx], + c[1][5], temp[3]); + c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx], + c[0][6], temp[0]); + c[1][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[1][weight_idx], + c[1][6], temp[1]); + c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx], + c[0][7], temp[2]); + c[1][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[1][weight_idx], + c[1][7], temp[3]); + } + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); +}; +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx], + c[0][0], temp[0]); + c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx], + c[0][1], temp[1]); + c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx], + c[0][2], temp[2]); + c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx], + c[0][3], temp[3]); + c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx], + c[0][4], temp[0]); + c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx], + c[0][5], temp[1]); + c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx], + c[0][6], temp[2]); + c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx], + c[0][7], temp[3]); + } + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 2; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int 
simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, + ld_weight_oc); + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 3; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t 
dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, + ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + load_helper( + dot4_weight, weight_ptr + 2 * filter_width * oc_step, + ld_weight_oc); + load_helper( + src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 5; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, \ + ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); 
\ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, \ + nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ + 0); \ + cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); + UNROLL_CALL_RAW(5, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 7; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, \ + ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, \ + nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ + 0); \ + cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); + + UNROLL_CALL_RAW(7, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +} // namespace + +namespace int8_direct_nchw_nchw44 { +/** + * pack {oc / 
4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} + * pack interleave two adjacent row in filter to one row + * */ +template <> +void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, + const int ic, const int fh, + const int fw, const int oc) { + constexpr int oc_step = 4; + const int fw2 = round_up(fw, 4); + const int fw_remain = fw2 - fw; + const int dst_ic_stride = fh * fw2; + const int oc_step_stride = fh * fw2 * ic * oc_step; + static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15}; + uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); + rep_step(oc_idx, oc, oc_step) { + int32_t* dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + const int32_t* src_temp_ptr = reinterpret_cast( + src_ptr + oc_idx * ic * fh * fw); + // transpose ic and pad + rep(fh_idx, fh) { + rep(fw_idx, fw) { + rep(ic_idx, ic) { + *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; + src_temp_ptr++; + } + dst_temp_ptr++; + } + rep(ic_idx, ic) { + memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, + sizeof(int8_t) * oc_step * fw_remain); + } + dst_temp_ptr += fw_remain; + } + // transpose fw oc + int8_t* trans_dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + + rep_step(idx, oc_step_stride, 16) { + int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); + vst1q_s8(trans_dst_temp_ptr + idx, + vqtbl1q_s8(temp, tbl_transpose_4x4)); + } + } +}; + +/** + * pack (ic, h, w) to (ic, h, w * 16) + * pack interleave two adjacent row in src and repeat 4 times, store to one row + * */ +template <> +void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, + int8_t* sptr_base, const int ic, + const int pad_top, const int pad_bottom, + const int, const int, const int ih, + const int iw, const int iw2, const int pw, + int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3}; + uint8x16_t tbl_idx = 
vld1q_u8(&reorder_idx[0]); + + constexpr int iw_step = 4; + constexpr int pack_iw_len = 16; + const int ic_stride = ih * iw; + const int iw_with_pad = iw + 2 * pw; + const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; + rep(ic_idx, ic) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * + pack_iw_len); + sptr_base += iw2 * pad_top * pack_iw_len; + rep(ih_idx, ih) { + memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); + memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); + for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { + int8x16_t src[4]; + int8x16_t dst[4]; + src[0] = vld1q_s8(temp_ptr + iw_idx); + src[1] = vld1q_s8(temp_ptr + iw_idx + 1); + src[2] = vld1q_s8(temp_ptr + iw_idx + 2); + src[3] = vld1q_s8(temp_ptr + iw_idx + 3); + dst[0] = vqtbl1q_s8(src[0], tbl_idx); + dst[1] = vqtbl1q_s8(src[1], tbl_idx); + dst[2] = vqtbl1q_s8(src[2], tbl_idx); + dst[3] = vqtbl1q_s8(src[3], tbl_idx); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); + } + for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { + int8x16_t src = vld1q_s8(temp_ptr + iw_idx); + int8x16_t dst = vqtbl1q_s8(src, tbl_idx); + vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst); + } + sptr_base += iw2 * pack_iw_len; + sptr += iw; + } + sptr_base += iw2 * pad_bottom * pack_iw_len; + } +} + +template +struct ConvDiectStrideInt8NchwNchw44 { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, int8_t* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr int stride = 1; + constexpr size_t fh = filter_size; + constexpr size_t fw = (filter_size + 3) / 4 
* 4; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = 1; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = stride; + constexpr size_t stride_w = stride; + constexpr int pack_iw_len = 16; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } + + if (oc_remain > 0) { + size_t oc_idx = oc_end; + 
const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, + filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + } +}; + +#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template struct ConvDiectStrideInt8NchwNchw44; + +#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) + +#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANCE_CONV_KERN(stride) \ + INSTANCE_BIAS_MODE_PARAM(stride, 2) \ + INSTANCE_BIAS_MODE_PARAM(stride, 3) \ + INSTANCE_BIAS_MODE_PARAM(stride, 5) \ + INSTANCE_BIAS_MODE_PARAM(stride, 7) + +INSTANCE_CONV_KERN(1); + +} // namespace int8_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp 
b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp new file mode 100644 index 00000000..c008aa29 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp @@ -0,0 +1,1412 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" +#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" +namespace megdnn { +namespace arm_common { +namespace { +/** + * @brief core code for calculation patten + * + * @tparam src_idx is offset of src reg + * @tparam weight_idx is offset of weight reg + * @tparam c_dim is output channel + * @tparam Func mla operation funcion + * @tparam stride + * @tparam T outpur regs type + * @tparam T2 src regs type + * @tparam T3 weight regs type + * @tparam T4 temp regs type + */ + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); +}; +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { + ShiftCalHelper::impl(c, src, weight, temp); +} +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[1][0] = Func::impl(src[0 + src_idx], 
weight[1][weight_idx], c[1][0], + temp[1]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], + temp[3]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], + temp[1]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], + temp[3]); + } + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); + } +}; +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + } + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + c[0][3] = 
Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + } +}; + +/** + * filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2:3, 4] dot4, \ + * f0[0:1, 4:5, 4] dot4, \ + * f0[0:1, 6, 4] dot2, \ + * ... + * f0[6, 0:1, 4] dot2, \ + * f0[6, 2:3, 4] dot2, \ + * f0[6, 4:5, 4] dot2, \ + * f0[6, 6, 4] dot1, \ + * look like: + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |---|---|---|-| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int stride = 2; + constexpr int filter_size = 7; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_step = 2; + constexpr int fh_end = filter_size / fh_step * fh_step; + constexpr int c_dim = OCHelper::val; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[6]; + int8x16_t dot4_weight[c_dim][3]; + int16x8_t temp_c[4]; + load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + cal_helper<1, 1, c_dim, 
Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + cal_helper<2, 2, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + weight_ptr += filter_size * pack_iw_len * fh_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 6 * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][3]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[6]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +#if MEGDNN_AARCH64 +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + uint8x16_t vtbl = vld1q_u8(src_idx_buffer); + + // constexpr int stride = 2; + constexpr int 
oc_block = 8; + constexpr int remain_w = 0; + constexpr int filter_size = 7; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_step = 2; + constexpr int c_dim = OCHelper::val; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + const size_t src_step = fh_step * iw * ic_step * pack_iw_len; + const size_t weight_step = filter_size * pack_iw_len * fh_step; + const size_t weight_step_small = filter_size * pack_iw_len; + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + + const int8_t* weight_ptr_oc = weight_ptr + ld_dot4_weight_oc; + + const int8_t* nchw_src_ptr_last_line = + src_ptr + ic_idx * ic_stride + + 6 * iw * ic_step * pack_iw_len; + /** + * r0-r7 c + * r24-r31 temp + * r8-r15 src + * r16-r22 weight + * r23 vtbl + */ + asm volatile( + + "ldp q8, q9, [%[nchw_src_ptr]]\n" + "ldp q16, q17, [%[weight_ptr]]\n" + "ldp q10, q11, [%[nchw_src_ptr], #32]\n" + "smull v24.8h, v8.8b, v16.8b\n" + "ldp q19, q20, [%[weight_ptr_oc]]\n" + "smull v25.8h, v9.8b, v16.8b\n" + "ldp q12, q13, [%[nchw_src_ptr], #64]\n" + "smull v26.8h, v10.8b, v16.8b\n" + "ldr q18, [%[weight_ptr],#32]\n" + "smull v27.8h, v11.8b, v16.8b\n" + "ldr q21, [%[weight_ptr_oc],#32]\n" + "smull v28.8h, v8.8b, v19.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v8.16b, v19.16b\n" + "ldr d8, [%[nchw_src_ptr],#48]\n" + "smlal2 v29.8h, v9.16b, v19.16b\n" + "smlal2 v30.8h, v10.16b, v19.16b\n" + "smlal2 
v31.8h, v11.16b, v19.16b\n" + "smull v24.8h, v9.8b, v17.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v10.8b, v17.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v11.8b, v17.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v12.8b, v17.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v9.16b, v17.16b\n" + "smlal2 v25.8h, v10.16b, v17.16b\n" + "smlal2 v26.8h, v11.16b, v17.16b\n" + "smlal2 v27.8h, v12.16b, v17.16b\n" + "smull v28.8h, v9.8b, v20.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v10.8b, v20.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v11.8b, v20.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v12.8b, v20.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v9.16b, v20.16b\n" + "ldr d9, [%[nchw_src_ptr],#64]\n" + "smlal2 v29.8h, v10.16b, v20.16b\n" + "ldr d14, [%[nchw_src_ptr],#80]\n" + "smlal2 v30.8h, v11.16b, v20.16b\n" + "smlal2 v31.8h, v12.16b, v20.16b\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v10.16b, v18.16b\n" + "ldr d19, [%[weight_ptr_oc],#48]\n" + "smlal2 v25.8h, v11.16b, v18.16b\n" + "ldr d15, [%[nchw_src_ptr],#96]\n" + "smlal2 v26.8h, v12.16b, v18.16b\n" + "smlal2 v27.8h, v13.16b, v18.16b\n" + "ldr d18, [%[weight_ptr],#48]\n" + "smull v28.8h, v10.8b, v21.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v11.8b, v21.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v10.16b, v21.16b\n" + "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n" + "smlal2 v29.8h, v11.16b, v21.16b\n" + "ldp q10, q11, [%[nchw_src_ptr], #32]\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "smlal2 v30.8h, v12.16b, v21.16b\n" + "add 
%[weight_ptr_oc], %[weight_ptr_oc], " + "%[weight_step]\n" + "smlal2 v31.8h, v13.16b, v21.16b\n" + "ldp q16, q17, [%[weight_ptr]]\n" + "smull v24.8h, v8.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v15.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "ldp q8, q9, [%[nchw_src_ptr]]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ldp q19, q20, [%[weight_ptr_oc]]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "ldp q12, q13, [%[nchw_src_ptr], #64]\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "ldr q18, [%[weight_ptr],#32]\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v11.8b, v16.8b\n" + "ldr q21, [%[weight_ptr_oc],#32]\n" + "sadalp %[c13].4s, v31.8h\n" + //! 
fh = 2 + "smull v28.8h, v8.8b, v19.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v8.16b, v19.16b\n" + "ldr d8, [%[nchw_src_ptr],#48]\n" + "smlal2 v29.8h, v9.16b, v19.16b\n" + "smlal2 v30.8h, v10.16b, v19.16b\n" + "smlal2 v31.8h, v11.16b, v19.16b\n" + "smull v24.8h, v9.8b, v17.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v10.8b, v17.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v11.8b, v17.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v12.8b, v17.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v9.16b, v17.16b\n" + "smlal2 v25.8h, v10.16b, v17.16b\n" + "smlal2 v26.8h, v11.16b, v17.16b\n" + "smlal2 v27.8h, v12.16b, v17.16b\n" + "smull v28.8h, v9.8b, v20.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v10.8b, v20.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v11.8b, v20.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v12.8b, v20.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v9.16b, v20.16b\n" + "ldr d9, [%[nchw_src_ptr],#64]\n" + "smlal2 v29.8h, v10.16b, v20.16b\n" + "ldr d14, [%[nchw_src_ptr],#80]\n" + "smlal2 v30.8h, v11.16b, v20.16b\n" + "smlal2 v31.8h, v12.16b, v20.16b\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v10.16b, v18.16b\n" + "ldr d19, [%[weight_ptr_oc],#48]\n" + "smlal2 v25.8h, v11.16b, v18.16b\n" + "ldr d15, [%[nchw_src_ptr],#96]\n" + "smlal2 v26.8h, v12.16b, v18.16b\n" + "smlal2 v27.8h, v13.16b, v18.16b\n" + 
"ldr d18, [%[weight_ptr],#48]\n" + "smull v28.8h, v10.8b, v21.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v11.8b, v21.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v10.16b, v21.16b\n" + "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n" + "smlal2 v29.8h, v11.16b, v21.16b\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "smlal2 v30.8h, v12.16b, v21.16b\n" + "add %[weight_ptr_oc], %[weight_ptr_oc], " + "%[weight_step]\n" + "smlal2 v31.8h, v13.16b, v21.16b\n" + "ldp q16, q17, [%[weight_ptr]]\n" + "smull v24.8h, v8.8b, v18.8b\n" + "ldp q10, q11, [%[nchw_src_ptr], #32]\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v15.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "ldp q8, q9, [%[nchw_src_ptr]]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ldp q19, q20, [%[weight_ptr_oc]]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "ldp q12, q13, [%[nchw_src_ptr], #64]\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "ldr q18, [%[weight_ptr],#32]\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v11.8b, v16.8b\n" + "ldr q21, [%[weight_ptr_oc],#32]\n" + "sadalp %[c13].4s, v31.8h\n" + //! 
fh = 4 + "smull v28.8h, v8.8b, v19.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v8.16b, v19.16b\n" + "ldr d8, [%[nchw_src_ptr],#48]\n" + "smlal2 v29.8h, v9.16b, v19.16b\n" + "smlal2 v30.8h, v10.16b, v19.16b\n" + "smlal2 v31.8h, v11.16b, v19.16b\n" + "smull v24.8h, v9.8b, v17.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v10.8b, v17.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v11.8b, v17.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v12.8b, v17.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v9.16b, v17.16b\n" + "smlal2 v25.8h, v10.16b, v17.16b\n" + "smlal2 v26.8h, v11.16b, v17.16b\n" + "smlal2 v27.8h, v12.16b, v17.16b\n" + "smull v28.8h, v9.8b, v20.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v10.8b, v20.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v11.8b, v20.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v12.8b, v20.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v9.16b, v20.16b\n" + "ldr d9, [%[nchw_src_ptr],#64]\n" + "smlal2 v29.8h, v10.16b, v20.16b\n" + "ldr d14, [%[nchw_src_ptr],#80]\n" + "smlal2 v30.8h, v11.16b, v20.16b\n" + "smlal2 v31.8h, v12.16b, v20.16b\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v10.16b, v18.16b\n" + "ldr d19, [%[weight_ptr_oc],#48]\n" + "smlal2 v25.8h, v11.16b, v18.16b\n" + "ldr d15, [%[nchw_src_ptr],#96]\n" + "smlal2 v26.8h, v12.16b, v18.16b\n" + "smlal2 v27.8h, v13.16b, v18.16b\n" + 
"ldr d18, [%[weight_ptr],#48]\n" + "smull v28.8h, v10.8b, v21.8b\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v11.8b, v21.8b\n" + "add %[weight_ptr_oc], %[weight_ptr_oc], %[weight_step]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "ldr q16, [%[weight_ptr]]\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v10.16b, v21.16b\n" + "smlal2 v29.8h, v11.16b, v21.16b\n" + "ldp q10, q11, [%[nchw_src_ptr_last_line], #32]\n" + "smlal2 v30.8h, v12.16b, v21.16b\n" + "smlal2 v31.8h, v13.16b, v21.16b\n" + "ldp q12, q13, [%[nchw_src_ptr_last_line], #64]\n" + "smull v24.8h, v8.8b, v18.8b\n" + "ldr d21, [%[weight_ptr_oc],#16]\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v15.8b, v18.8b\n" + "ldr d18, [%[weight_ptr],#16]\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "ldp q8, q9, [%[nchw_src_ptr_last_line]]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ldr q19, [%[weight_ptr_oc]]\n" + "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" + "tbl v9.16b, {v9.16b}, %[vtbl].16b\n" + "sadalp %[c03].4s, v27.8h\n" + "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" + "tbl v11.16b, {v11.16b}, %[vtbl].16b\n" + "sadalp %[c10].4s, v28.8h\n" + "tbl v12.16b, {v12.16b}, %[vtbl].16b\n" + "tbl v13.16b, {v13.16b}, %[vtbl].16b\n" + "sadalp %[c11].4s, v29.8h\n" + /// last line//// + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "smull v27.8h, v11.8b, v16.8b\n" + "smlal2 v24.8h, v9.16b, v16.16b\n" + "smlal2 v25.8h, v10.16b, v16.16b\n" + "smlal2 
v26.8h, v11.16b, v16.16b\n" + "smlal2 v27.8h, v12.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "smlal2 v28.8h, v9.16b, v19.16b\n" + "dup v9.8b, v11.b[0]\n" + "smlal2 v29.8h, v10.16b, v19.16b\n" + "smlal2 v30.8h, v11.16b, v19.16b\n" + "smlal2 v31.8h, v12.16b, v19.16b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "add x10, %[nchw_src_ptr_last_line], #96\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v10.8b, v21.8b\n" + + "sadalp %[c01].4s, v25.8h\n" + "add x5, %[weight_ptr], #24\n" + "smull v29.8h, v11.8b, v21.8b\n" + "add x6, %[weight_ptr_oc], #24\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "dup v10.8b, v12.b[0]\n" + "sadalp %[c10].4s, v28.8h\n" + "ld1r {v12.8b}, [x10]\n" + "sadalp %[c11].4s, v29.8h\n" + "dup v11.8b, v13.b[0]\n" + "sadalp %[c12].4s, v30.8h\n" + "ld1r {v16.2s}, [x5]\n" + "sadalp %[c13].4s, v31.8h\n" + "sxtl v16.8h, v16.8b\n" + ///////////////last element///////// + "add %[weight_ptr], %[weight_ptr], %[weight_step_small]\n" + "sxtl v9.8h, v9.8b\n" + "ld1r {v19.2s}, [x6]\n" + "sxtl v10.8h, v10.8b\n" + "sxtl v11.8h, v11.8b\n" + "smlal %[c00].4s, v9.4h, v16.4h\n" + "sxtl v12.8h, v12.8b\n" + "smlal %[c01].4s, v10.4h, v16.4h\n" + "sxtl v19.8h, v19.8b\n" + "smlal %[c02].4s, v11.4h, v16.4h\n" + "smlal %[c03].4s, v12.4h, v16.4h\n" + "smlal %[c10].4s, v9.4h, v19.4h\n" + "smlal %[c11].4s, v10.4h, v19.4h\n" + "smlal %[c12].4s, v11.4h, v19.4h\n" + "smlal %[c13].4s, v12.4h, v19.4h\n" + : + + [c00] "+w"(c[0][0]), [c10] 
"+w"(c[1][0]), + [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), + [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), + [nchw_src_ptr] "+r"(nchw_src_ptr), + [weight_ptr] "+r"(weight_ptr), + [weight_ptr_oc] "+r"(weight_ptr_oc) + + : [vtbl] "w"(vtbl), + [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), + [src_step] "r"(src_step), [weight_step] "r"(weight_step), + [weight_step_small] "r"(weight_step_small) + : "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +#endif +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 2; + constexpr int filter_size = 5; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int ih_step = 2; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_end = filter_size / ih_step * ih_step; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[5]; + int8x16_t dot4_weight[c_dim][2]; + int16x8_t temp_c[4]; + load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, 
stride>( + c, src, dot4_weight, temp_c); + cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + weight_ptr += filter_size * pack_iw_len * ih_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + fh_end * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][2]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[5]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +/** + * filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2, 4] dot2, \ + * f0[2, 0:1, 4] dot2, \ + * f0[2, 2, 4] dot1 \ + * look like: + * |---|-| + * |x x|x| + * |x x|x| + * |-----| + * |x x|x| + * |-----| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) 
{ + constexpr int stride = 2; + constexpr int filter_size = 3; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + // first 2 line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[4]; + int8x16_t dot4_weight[c_dim][1]; + int16x8_t temp_c[4]; + load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_weight_oc); + load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + } + // last line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 2 * iw * ic_step * pack_iw_len; + int16x8_t temp_c[4]; + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_weight_oc); + load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( + src_dot2, nchw_src_ptr, 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 
0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * filter_size * pack_iw_len; + } + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; + +#if MEGDNN_AARCH64 +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 3; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int oc_block = 8; + constexpr int remain_w = 0; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const size_t weight_step = filter_size * filter_size * pack_iw_len; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + uint8x16_t vtbl = vld1q_u8(src_idx_buffer); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + const int8_t* nchw_src_ptr_last_line = + src_ptr + ic_idx * ic_stride + + 2 * iw * ic_step * pack_iw_len; + const int8_t* weight_ptr_oc = weight_ptr + ld_weight_oc; + /** + * r0-r7 c + * r24-r31 temp + * r8-r15 src + * r16-r19 weight + * r20-vtbl + */ + asm volatile( + //! 
load src 0,1 + "ldp q8,q9, [%[nchw_src_ptr]]\n" + "ldr q16, [%[weight_ptr]]\n" + "ldp q10,q11, [%[nchw_src_ptr], #32]\n" + "add x5, %[weight_ptr], #32\n" + "smull v24.8h, v8.8b, v16.8b\n" + "ldr q17, [%[weight_ptr_oc]]\n" + "smull v25.8h, v9.8b, v16.8b\n" + "add x6, %[weight_ptr_oc], #32\n" + "smull v26.8h, v10.8b, v16.8b\n" + "smull v27.8h, v11.8b, v16.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "add x7, %[nchw_src_ptr_last_line], #64\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v8.8b, v17.8b\n" + "ldr d12, [%[nchw_src_ptr],#16]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v9.8b, v17.8b\n" + "ldr d13, [%[nchw_src_ptr],#32]\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v10.8b, v17.8b\n" + "ldr d14, [%[nchw_src_ptr],#48]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v11.8b, v17.8b\n" + "ldr d18, [%[weight_ptr],#16]\n" + "smlal2 v28.8h, v8.16b, v17.16b\n" + "ldr d19, [%[weight_ptr_oc],#16]\n" + "smlal2 v29.8h, v9.16b, v17.16b\n" + "ldr d15, [%[nchw_src_ptr],#64]\n" + "smlal2 v30.8h, v10.16b, v17.16b\n" + "ldp q8,q9, [%[nchw_src_ptr_last_line]]\n" + "smull v24.8h, v12.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smlal2 v31.8h, v11.16b, v17.16b\n" + "ldp q10,q11, [%[nchw_src_ptr_last_line], #32]\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v13.8b, v18.8b\n" + "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "ldr d16, [%[weight_ptr],#24]\n" + "sadalp %[c13].4s, v31.8h\n" + "ldr d17, [%[weight_ptr_oc],#24]\n" + "smull v27.8h, v15.8b, v18.8b\n" + "tbl v9.16b, {v9.16b}, %[vtbl].16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v12.8b, v19.8b\n" + "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v13.8b, v19.8b\n" + "tbl v11.16b, {v11.16b}, %[vtbl].16b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "ld1r 
{v18.2s}, [x5]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ld1r {v19.2s}, [x6]\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "dup v12.8b, v9.b[0]\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "dup v12.8b, v9.b[0]\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v27.8h, v11.8b, v16.8b\n" + "dup v13.8b, v10.b[0]\n" + "smull v28.8h, v8.8b, v17.8b\n" + "dup v14.8b, v11.b[0]\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v17.8b\n" + "ld1r {v15.8b}, [x7]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v17.8b\n" + "sxtl v12.8h, v12.8b\n" + "sxtl v18.8h, v18.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v17.8b\n" + "sxtl v13.8h, v13.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal %[c00].4s, v12.4h, v18.4h\n" + "sxtl v14.8h, v14.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smlal %[c01].4s, v13.4h, v18.4h\n" + "sxtl v15.8h, v15.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smlal %[c02].4s, v14.4h, v18.4h\n" + "sxtl v19.8h, v19.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "smlal %[c03].4s, v15.4h, v18.4h\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal %[c10].4s, v12.4h, v19.4h\n" + "smlal %[c11].4s, v13.4h, v19.4h\n" + "smlal %[c12].4s, v14.4h, v19.4h\n" + "smlal %[c13].4s, v15.4h, v19.4h\n" + : + + [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), + [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), + [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), + + [weight_ptr] "+r"(weight_ptr), + [weight_ptr_oc] "+r"(weight_ptr_oc) + : [vtbl] "w"(vtbl), [nchw_src_ptr] "r"(nchw_src_ptr), + [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), + [weight_step] "r"(weight_step) + : "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + } + 
store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +#endif + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[4]; + int8x16_t dot4_weight[c_dim][1]; + int16x8_t temp_c[4]; + load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_weight_oc); + load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + weight_ptr += oc_step * filter_size * filter_size; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; + +enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; +template +MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, + int left_pad, int right_pad, + const int iw) { + const int8_t* src_row_0 = inptr; + const int8_t* src_row_1 = inptr + iw; + constexpr int combine_row = 2; + constexpr int iw_step = 16; + constexpr int src_expand = 4; + constexpr int out_gap = iw_step * src_expand; + const int iw_end = iw / iw_step * iw_step; + + memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * left_pad * src_expand; + + for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { + int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); + int8x16_t row1 = vdupq_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1q_s8(src_row_1 + iw_idx); + 
} else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdupq_n_s8(0); + } + int8x16x2_t pack_rows = vzipq_s8(row0, row1); +#define STORE_8S8(step) \ + vst1_s8(outptr + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[0]), step))); + + UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 +#define STORE_8S8(step) \ + vst1_s8(outptr + out_gap + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[1]), step))); + + UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 + outptr += out_gap * combine_row; + } + for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { + int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); + int8x8_t row1 = vdup_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1_dup_s8(src_row_1 + iw_idx); + } else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdup_n_s8(0); + } + int8x8x2_t pack_rows = vzip_s8(row0, row1); + vst1_s8(outptr, pack_rows.val[0]); + outptr += src_expand * combine_row; + } + memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * right_pad * src_expand; +} + +} // namespace + +namespace int8_direct_nchw_nchw44 { +/** + * pack (ic, h, w) to (ic, h / 2, 2 * w) + * pack interleave two adjacent row in src and repeat 4 times, store to one row + * */ +template <> +void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, + const int ic, const int top_pad, + const int bottom_pad, const int left_pad, + const int right_pad, const int ih, + const int iw, const int, const int, + int8_t*) { + constexpr int src_expand = 4; + constexpr int oh_step = 2; + const int oh = ih + top_pad + bottom_pad; + const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; + const int ow = (iw + left_pad + right_pad) * src_expand; + + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int oh_idx = 0; + for (; oh_idx < top_pad; oh_idx += oh_step) { + if (top_pad - oh_idx >= oh_step) { + memset(outptr, 0, oh_step 
* ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + + for (; oh_idx < oh_end; oh_idx += oh_step) { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += oh_step * iw; + outptr += oh_step * ow; + } + + for (; oh_idx < oh; oh_idx += oh_step) { + const int last_pad = oh_idx - ih - top_pad; + if (last_pad >= 0) { + memset(outptr, 0, oh_step * ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + } +} + +/** + * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} + * pack interleave two adjacent row in filter to one row + * */ +template <> +void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, + const int ic, const int fh, + const int fw, const int oc) { + constexpr int oc_step = 4; + constexpr int ic_step = 2; + constexpr int fh_step = 2; + constexpr int fw_step = 2; + const int ic_end = ic / ic_step * ic_step; + const int ic_remain = ic - ic_end; + const int fh_end = fh / fh_step * fh_step; + const int fh_remain = fh - fh_end; + const int fw_end = fw / fw_step * fw_step; + const int fw_remain = fw - fw_end; + const int filter_stride = ic * oc_step; + static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, + 4, 12, 5, 13, 6, 14, 7, 15}; + uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); + for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { + const int ic_offset = ic_idx * oc_step; + int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; + int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + 
int8x8_t row_0 = vld1_s8(filter_ptr); + int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + } + if (fh_remain > 0) { + const int fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + vst1_lane_s32((int32_t*)output_ic0, + vreinterpret_s32_s8(row_0), 0); + vst1_lane_s32((int32_t*)output_ic1, + vreinterpret_s32_s8(row_0), 1); + output_ic0 += 4; + output_ic1 += 4; + } + } + } + if (ic_remain > 0) { + const int ic_offset = ic_end * oc_step; + int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + fw * filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + 
output_ic0 += 8; + } + } + if (fh_remain > 0) { + const int fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + output_ic0 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); + output_ic0 += 4; + } + } + } + inptr += oc_step * fh * fw * ic; + outptr += oc_step * fh * fw * ic; + } +} + +template +struct ConvDiectStrideInt8NchwNchw44 { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, int8_t* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t stride = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = + stride == 2 ? filter_size : (filter_size + 3) / 4 * 4; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = stride == 2 ? 2 : 1; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = stride == 2 ? 
4 : 8; + constexpr size_t stride_h = stride; + constexpr size_t stride_w = stride; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + break; + + UNROLL_CALL_RAW(4, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const 
size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, + filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + } +}; + +#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template struct ConvDiectStrideInt8NchwNchw44; + +#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) + +#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANCE_CONV_KERN(stride) \ + INSTANCE_BIAS_MODE_PARAM(stride, 2) \ + INSTANCE_BIAS_MODE_PARAM(stride, 3) \ + INSTANCE_BIAS_MODE_PARAM(stride, 5) \ + INSTANCE_BIAS_MODE_PARAM(stride, 7) + +INSTANCE_CONV_KERN(2); + +} // namespace int8_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp index 0626b23e..62c63fd2 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp @@ -114,7 +114,7 @@ static void copy_padding_kern(const 
WorkspaceBundle& bundle, rep(ih_idx, IH) { std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); sptr_base += nr_pad_w; - nchw44_pack_src(sptr, sptr_base, IW); + int8_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW); sptr_base += IW * pack_ic * expend_element; sptr += IW * pack_ic; std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); @@ -125,8 +125,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, } } -template +template static void do_conv_kern(const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, const ConvBiasImpl::NCBKernIndex& ncb_index, @@ -182,8 +182,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, kern_param.bias(batch_id, group_id) + oc_idx; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; - nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); - conv_direct_int8_nchw44( + int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight, + oc_block / 4 * IC / 4 * FH * FW); + int8_direct_nchw44::conv_direct_int8_nchw44( sptr, packed_weight, bptr, nullptr, static_cast(dst), oc_block, IC, IH2, IW2, OH, OW, op); } @@ -233,40 +235,38 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( size_t N = param.n; size_t IC = fm.icpg; size_t OC = fm.ocpg; - size_t OW = param.osz[1]; size_t group = fm.group; size_t fh = fm.spatial[0]; size_t fw = fm.spatial[1]; WorkspaceBundle wbundle = get_bundle(param); conv_fun do_conv_fun = nullptr; - int ow_remain = OW % 8; + bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; // NOTE: remain_w is not used to gen hash of midout for compatible with changing // shape runtime -#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \ +#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \ MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ + 
do_conv_fun = do_conv_kern; \ } \ MIDOUT_END(); -#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \ +#define GET_OP_PARAM(stride, filter, bias_mode) \ if (need_post_process) { \ switch (param.nonlineMode) { \ case param::ConvBias::NonlineMode::IDENTITY: \ DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - remain_w, \ + \ TypeCvtOp) \ break; \ case param::ConvBias::NonlineMode::RELU: \ DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - remain_w, \ + \ ReluOp) \ break; \ case param::ConvBias::NonlineMode::H_SWISH: \ DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - remain_w, \ + \ HSwishOp) \ break; \ default: \ @@ -277,7 +277,7 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( switch (param.nonlineMode) { \ case param::ConvBias::NonlineMode::IDENTITY: \ DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ - remain_w, NoneOp) \ + NoneOp) \ break; \ default: \ megdnn_assert( \ @@ -287,48 +287,17 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( } \ } -#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ - switch (ow_remain) { \ - case 0: \ - GET_OP_PARAM(stride, filter, bias_mode, 0); \ - break; \ - case 1: \ - GET_OP_PARAM(stride, filter, bias_mode, 1); \ - break; \ - case 2: \ - GET_OP_PARAM(stride, filter, bias_mode, 2); \ - break; \ - case 3: \ - GET_OP_PARAM(stride, filter, bias_mode, 3); \ - break; \ - case 4: \ - GET_OP_PARAM(stride, filter, bias_mode, 4); \ - break; \ - case 5: \ - GET_OP_PARAM(stride, filter, bias_mode, 5); \ - break; \ - case 6: \ - GET_OP_PARAM(stride, filter, bias_mode, 6); \ - break; \ - case 7: \ - GET_OP_PARAM(stride, filter, bias_mode, 7); \ - break; \ - default: \ - megdnn_assert(0); \ - } - -#define GET_BIAS_MODE_PARAM(stride, filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_REMAIN_W_PARAM(stride, filter, \ - BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ 
- default: \ - megdnn_assert(0); \ - break; \ +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define DISPATCH_CONV_KERN(stride) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h index e66a50ed..9bf1f50d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h @@ -19,506 +19,7 @@ namespace megdnn { namespace arm_common { -namespace { - -template -static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc4 = oc_step * fh * fw * ic; - - int32x4_t c[2][8]; - int8x16_t weight[2][2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[4]; - - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 
16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0][0] = vld1q_s8(read_weight_ptr); - weight[0][1] = vld1q_s8(read_weight_ptr + 16); - weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); - weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); - - c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]); - - c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]); - - c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]); - - c[0][6] = 
vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); -} - -template -static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[1][8]; - int8x16_t weight[1][2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0][0] = vld1q_s8(read_weight_ptr); 
- weight[0][1] = vld1q_s8(read_weight_ptr + 16); - - c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]); - - c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]); - - c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); -} - -template -struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc); -}; -template -struct KerNeonDirectStride1Int8 { - static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*, - int, int, int, const Op&, int) { - megdnn_throw("no impl"); - } -}; -/** -dot like impl. 
dot 4 ic to 1 oc, accumale to c -example: (format like weight) -packed weight -low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> ---------------------------------------------------------------------- -high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> -dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> -**/ -//! TODO: can try oh = 2 impl, oc = 8 impl -template -struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 3; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[3]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = 
vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); - - c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); - - c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 5; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - 
constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[5]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0], 
src[3], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); - - c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - - c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt<1, 
remain_w, Op, DstType*>( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 7; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[7]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); - weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); - c[0][1] 
= vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - - c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], 
temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]); - - src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - - c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( - c, op, dst_ptr, ld_dst_oc); - } -}; +namespace int8_direct_nchw44 { /** origin weight shape @@ -568,799 +69,8 @@ static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { } } -template -void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, - const int8_t* filter, - const int32_t* 
bias, int32_t* temp, - DstType* dst, const size_t oc, - const size_t ic, const size_t ih, - const size_t iw, const size_t oh, - const size_t ow, const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 2; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t big_oc_step = 8; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_oc = oh * ow * oc_step; - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - } - } - if (oc_remain > 0) { - const size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx 
* img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - } - } -} -template -void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - DstType* dst, const size_t oc, - const size_t ic, const size_t ih, - const size_t iw, const size_t oh, - const size_t ow, const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const int ld_dst_oc = oh * ow * oc_step; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDirectStride1Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - 
KerNeonDirectStride1Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); - } - } - } -} -/////////////////////stride 2///////////////// -template -struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc); -}; -template -struct KerNeonDirectStride2Int8 { - static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*, - int, int, int, const Op&, int) { - megdnn_throw("no impl"); - } -}; - -template -static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc4 = oc_step * fh * fw * ic; - - int32x4_t c[2][8]; - int8x16_t weight[2][2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[4]; - - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8(src_ic_0_3 + 16); - src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); - src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); - src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0][0] = 
vld1q_s8(read_weight_ptr); - weight[0][1] = vld1q_s8(read_weight_ptr + 16); - weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); - weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); - - c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]); - - c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]); - - src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); - c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]); - - src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[4] = 
vld1q_s8(src_ic_0_3 + 13 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); - c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); -} - -template -static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - - // 
oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); - c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]); - - src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); - c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); -} -/** -dot like impl. 
dot 4 ic to 1 oc, accumale to c -example: (format like weight) -packed weight -low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> ---------------------------------------------------------------------- -high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> -dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> -**/ -// TODO: can try oh = 2 impl, oc = 8 impl -template -struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 3; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[3]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = 
vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); - c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); - - c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); - c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); - - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); - c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; -template -struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const 
int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 5; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[5]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); - c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); - c[0][1] = 
vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]); - - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); - c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); - - src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); - c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); - 
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; -template -struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { - constexpr int filter_size = 7; - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int oc_step = 4; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[c_dim][8]; - int8x16_t weight[7]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); - src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); - src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); - src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * 
ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); - weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); - - c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); - c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); - c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); - c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]); - c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); - c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); - c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[4], 
src[0], c[0][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]); - c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]); - c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]); - c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]); - - src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); - c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); - c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]); - c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]); - c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]); - - src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); - src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); - src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); - c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); - c[0][7] = 
vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]); - c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]); - c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]); - c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -void conv_direct_stride2_2x2_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, - DstType* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, const Op& op) { - constexpr size_t filter_size = 2; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t big_oc_step = 8; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t out_img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_dst_oc = oh * ow * oc_step; - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - 
dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - } - } - if (oc_remain > 0) { - const size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - } - } -} -template -void conv_direct_stride2_int8_nchw44_kern( - const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, - DstType* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, const Op& op) { - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / 
ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const int ld_dst_oc = oh * ow * oc_step; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDirectStride2Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - KerNeonDirectStride2Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); - } - } - } -} -template +template struct ConvDirectInt8Nchw44Choose { static void impl(const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, DstType* dst, @@ -1369,59 +79,19 @@ struct ConvDirectInt8Nchw44Choose { const Op& op); }; -template -struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - if (filter_size == 2) { - conv_direct_stride1_2x2_int8_nchw44( - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); - } else { - conv_direct_stride1_int8_nchw44_kern( - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); - } - } -}; -template -struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, 
const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - if (filter_size == 2) { - conv_direct_stride2_2x2_int8_nchw44( - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); - } else { - conv_direct_stride2_int8_nchw44_kern( - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); - } - } -}; -template +template void conv_direct_int8_nchw44(const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, DstType* dst, const size_t oc, const size_t ic, const size_t ih, const size_t iw, const size_t oh, const size_t ow, const Op& op) { - ConvDirectInt8Nchw44Choose::impl(src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } -} // namespace +} // namespace int8_direct_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp index 93999a61..3da2d94d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp @@ -117,11 +117,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, const size_t tmp_size = get_temp_bytes(iw, pw); int8_t* tmp_ptr = reinterpret_cast(bundle.get(2)) + ncb_index.thread_id * tmp_size; - pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, - iw, iw2, pw, tmp_ptr); + int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<1>( + sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, tmp_ptr); } else { - pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, - iw, iw2, pw, nullptr); + int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<2>( + sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr); } } static void pack_weight(const WorkspaceBundle& bundle, @@ -142,11 +142,11 @@ static void pack_weight(const WorkspaceBundle& bundle, group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; if (stride_h == 
1) { - pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw, - oc_block); + int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<1>( + fptr, packed_weight, ic, fh, fw, oc_block); } else { - pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw, - oc_block); + int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<2>( + fptr, packed_weight, ic, fh, fw, oc_block); } } template @@ -208,7 +208,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, int8_t* packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; - conv_direct_int8_nchw_nchw44( + int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44( sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, ow, op); } diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h index 4715d809..90126638 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h @@ -21,1369 +21,8 @@ namespace megdnn { namespace arm_common { -namespace { -/** - * @brief core code for calculation patten - * - * @tparam src_idx is offset of src reg - * @tparam weight_idx is offset of weight reg - * @tparam c_dim is output channel - * @tparam Func mla operation funcion - * @tparam stride - * @tparam T outpur regs type - * @tparam T2 src regs type - * @tparam T3 weight regs type - * @tparam T4 temp regs type - */ -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], - temp[1]); - c[0][1] = 
Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], - temp[3]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], - temp[0]); - c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], - temp[1]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); - c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], - temp[3]); - } - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); - c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); - c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); - c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); - c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], - temp[0]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); - } - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); - } 
-}; - -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], - c[0][0], temp[0]); - c[1][0] = Func::impl(src[(0 + src_idx) % 8], weight[1][weight_idx], - c[1][0], temp[1]); - c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx], - c[0][1], temp[2]); - c[1][1] = Func::impl(src[(1 + src_idx) % 8], weight[1][weight_idx], - c[1][1], temp[3]); - c[0][2] = Func::impl(src[(2 + src_idx) % 8], weight[0][weight_idx], - c[0][2], temp[0]); - c[1][2] = Func::impl(src[(2 + src_idx) % 8], weight[1][weight_idx], - c[1][2], temp[1]); - c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx], - c[0][3], temp[2]); - c[1][3] = Func::impl(src[(3 + src_idx) % 8], weight[1][weight_idx], - c[1][3], temp[3]); - - c[0][4] = Func::impl(src[(4 + src_idx) % 8], weight[0][weight_idx], - c[0][4], temp[0]); - c[1][4] = Func::impl(src[(4 + src_idx) % 8], weight[1][weight_idx], - c[1][4], temp[1]); - c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx], - c[0][5], temp[2]); - c[1][5] = Func::impl(src[(5 + src_idx) % 8], weight[1][weight_idx], - c[1][5], temp[3]); - c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx], - c[0][6], temp[0]); - c[1][6] = Func::impl(src[(6 + src_idx) % 8], weight[1][weight_idx], - c[1][6], temp[1]); - c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx], - c[0][7], temp[2]); - c[1][7] = Func::impl(src[(7 + src_idx) % 8], weight[1][weight_idx], - c[1][7], temp[3]); - } - static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], - c[0][0], temp[0]); - c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx], - c[0][1], temp[1]); - c[0][2] = Func::impl(src[(2 + src_idx) % 8], 
weight[0][weight_idx], - c[0][2], temp[2]); - c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx], - c[0][3], temp[3]); - c[0][4] = Func::impl(src[(4 + src_idx) % 8], weight[0][weight_idx], - c[0][4], temp[0]); - c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx], - c[0][5], temp[1]); - c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx], - c[0][6], temp[2]); - c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx], - c[0][7], temp[3]); - } - static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); -}; - -template -MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { - ShiftCalHelper::impl(c, src, weight, temp); -} -template -MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); -}; - -template -struct OCHelper { -public: - static const int val = 0; -}; -template <> -struct OCHelper<4> { -public: - static const int val = 1; -}; -template <> -struct OCHelper<8> { -public: - static const int val = 2; -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op); -}; -/** - * filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] - * calculate sequence \ - * f0[0:1, 0:1, 4] dot4, \ - * f0[0:1, 2:3, 4] dot4, \ - * f0[0:1, 4:5, 4] dot4, \ - * f0[0:1, 6, 4] dot2, \ - * ... 
- * f0[6, 0:1, 4] dot2, \ - * f0[6, 2:3, 4] dot2, \ - * f0[6, 4:5, 4] dot2, \ - * f0[6, 6, 4] dot1, \ - * look like: - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |---|---|---|-| - **/ -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int filter_size = 7; - constexpr int ic_step = 1; - constexpr int oc_step = 4; - constexpr int pack_iw_len = 4; - constexpr int fh_step = 2; - constexpr int fh_end = filter_size / fh_step * fh_step; - constexpr int c_dim = OCHelper::val; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; - - int32x4_t c[c_dim][4]; - - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - int8x16_t src[6]; - int8x16_t dot4_weight[c_dim][3]; - int16x8_t temp_c[4]; - load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - cal_helper<2, 2, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, 
Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( - c, src_dot2, dot2_weight, temp_c); - weight_ptr += filter_size * pack_iw_len * fh_step; - } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 6 * iw * ic_step * pack_iw_len; - - int8x8_t dot2_weight[c_dim][3]; - int16x8_t temp_c[4]; - int8x8_t src_dot2[6]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); - weight_ptr += filter_size * pack_iw_len; - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -#if MEGDNN_AARCH64 -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - uint8x16_t vtbl = vld1q_u8(src_idx_buffer); - - // constexpr int stride = 2; - constexpr int oc_block = 8; - constexpr int remain_w = 0; - constexpr int filter_size = 7; - constexpr int ic_step = 1; - constexpr int oc_step = 4; - constexpr int pack_iw_len = 4; - constexpr int fh_step = 2; - constexpr int c_dim = OCHelper::val; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_dot4_weight_oc = oc_step 
* filter_size * filter_size * ic; - const size_t src_step = fh_step * iw * ic_step * pack_iw_len; - const size_t weight_step = filter_size * pack_iw_len * fh_step; - const size_t weight_step_small = filter_size * pack_iw_len; - int32x4_t c[c_dim][4]; - - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - - const int8_t* weight_ptr_oc = weight_ptr + ld_dot4_weight_oc; - - const int8_t* nchw_src_ptr_last_line = - src_ptr + ic_idx * ic_stride + - 6 * iw * ic_step * pack_iw_len; - /** - * r0-r7 c - * r24-r31 temp - * r8-r15 src - * r16-r22 weight - * r23 vtbl - */ - asm volatile( - - "ldp q8, q9, [%[nchw_src_ptr]]\n" - "ldp q16, q17, [%[weight_ptr]]\n" - "ldp q10, q11, [%[nchw_src_ptr], #32]\n" - "smull v24.8h, v8.8b, v16.8b\n" - "ldp q19, q20, [%[weight_ptr_oc]]\n" - "smull v25.8h, v9.8b, v16.8b\n" - "ldp q12, q13, [%[nchw_src_ptr], #64]\n" - "smull v26.8h, v10.8b, v16.8b\n" - "ldr q18, [%[weight_ptr],#32]\n" - "smull v27.8h, v11.8b, v16.8b\n" - "ldr q21, [%[weight_ptr_oc],#32]\n" - "smull v28.8h, v8.8b, v19.8b\n" - "smlal2 v24.8h, v8.16b, v16.16b\n" - "smlal2 v25.8h, v9.16b, v16.16b\n" - "smlal2 v26.8h, v10.16b, v16.16b\n" - "smlal2 v27.8h, v11.16b, v16.16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v10.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v11.8b, v19.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v8.16b, v19.16b\n" - "ldr d8, [%[nchw_src_ptr],#48]\n" - "smlal2 v29.8h, v9.16b, v19.16b\n" - "smlal2 v30.8h, v10.16b, v19.16b\n" - "smlal2 v31.8h, v11.16b, v19.16b\n" - "smull v24.8h, v9.8b, v17.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v10.8b, v17.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v11.8b, v17.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v12.8b, v17.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v9.16b, v17.16b\n" - 
"smlal2 v25.8h, v10.16b, v17.16b\n" - "smlal2 v26.8h, v11.16b, v17.16b\n" - "smlal2 v27.8h, v12.16b, v17.16b\n" - "smull v28.8h, v9.8b, v20.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v10.8b, v20.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v11.8b, v20.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v12.8b, v20.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v9.16b, v20.16b\n" - "ldr d9, [%[nchw_src_ptr],#64]\n" - "smlal2 v29.8h, v10.16b, v20.16b\n" - "ldr d14, [%[nchw_src_ptr],#80]\n" - "smlal2 v30.8h, v11.16b, v20.16b\n" - "smlal2 v31.8h, v12.16b, v20.16b\n" - "smull v24.8h, v10.8b, v18.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v11.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v12.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v13.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v10.16b, v18.16b\n" - "ldr d19, [%[weight_ptr_oc],#48]\n" - "smlal2 v25.8h, v11.16b, v18.16b\n" - "ldr d15, [%[nchw_src_ptr],#96]\n" - "smlal2 v26.8h, v12.16b, v18.16b\n" - "smlal2 v27.8h, v13.16b, v18.16b\n" - "ldr d18, [%[weight_ptr],#48]\n" - "smull v28.8h, v10.8b, v21.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v11.8b, v21.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v12.8b, v21.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v13.8b, v21.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v10.16b, v21.16b\n" - "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n" - "smlal2 v29.8h, v11.16b, v21.16b\n" - "ldp q10, q11, [%[nchw_src_ptr], #32]\n" - "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" - "smlal2 v30.8h, v12.16b, v21.16b\n" - "add %[weight_ptr_oc], %[weight_ptr_oc], " - "%[weight_step]\n" - "smlal2 v31.8h, v13.16b, v21.16b\n" - "ldp q16, q17, [%[weight_ptr]]\n" - "smull v24.8h, v8.8b, v18.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v9.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v14.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - 
"smull v27.8h, v15.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v28.8h, v8.8b, v19.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "ldp q8, q9, [%[nchw_src_ptr]]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v14.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v15.8b, v19.8b\n" - "ldp q19, q20, [%[weight_ptr_oc]]\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v24.8h, v8.8b, v16.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v9.8b, v16.8b\n" - "ldp q12, q13, [%[nchw_src_ptr], #64]\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v10.8b, v16.8b\n" - "ldr q18, [%[weight_ptr],#32]\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v11.8b, v16.8b\n" - "ldr q21, [%[weight_ptr_oc],#32]\n" - "sadalp %[c13].4s, v31.8h\n" - //! fh = 2 - "smull v28.8h, v8.8b, v19.8b\n" - "smlal2 v24.8h, v8.16b, v16.16b\n" - "smlal2 v25.8h, v9.16b, v16.16b\n" - "smlal2 v26.8h, v10.16b, v16.16b\n" - "smlal2 v27.8h, v11.16b, v16.16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v10.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v11.8b, v19.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v8.16b, v19.16b\n" - "ldr d8, [%[nchw_src_ptr],#48]\n" - "smlal2 v29.8h, v9.16b, v19.16b\n" - "smlal2 v30.8h, v10.16b, v19.16b\n" - "smlal2 v31.8h, v11.16b, v19.16b\n" - "smull v24.8h, v9.8b, v17.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v10.8b, v17.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v11.8b, v17.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v12.8b, v17.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v9.16b, v17.16b\n" - "smlal2 v25.8h, v10.16b, v17.16b\n" - "smlal2 v26.8h, v11.16b, v17.16b\n" - "smlal2 v27.8h, v12.16b, v17.16b\n" - "smull v28.8h, v9.8b, v20.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v10.8b, v20.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v11.8b, v20.8b\n" - "sadalp %[c02].4s, 
v26.8h\n" - "smull v31.8h, v12.8b, v20.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v9.16b, v20.16b\n" - "ldr d9, [%[nchw_src_ptr],#64]\n" - "smlal2 v29.8h, v10.16b, v20.16b\n" - "ldr d14, [%[nchw_src_ptr],#80]\n" - "smlal2 v30.8h, v11.16b, v20.16b\n" - "smlal2 v31.8h, v12.16b, v20.16b\n" - "smull v24.8h, v10.8b, v18.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v11.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v12.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v13.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v10.16b, v18.16b\n" - "ldr d19, [%[weight_ptr_oc],#48]\n" - "smlal2 v25.8h, v11.16b, v18.16b\n" - "ldr d15, [%[nchw_src_ptr],#96]\n" - "smlal2 v26.8h, v12.16b, v18.16b\n" - "smlal2 v27.8h, v13.16b, v18.16b\n" - "ldr d18, [%[weight_ptr],#48]\n" - "smull v28.8h, v10.8b, v21.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v11.8b, v21.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v12.8b, v21.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v13.8b, v21.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v10.16b, v21.16b\n" - "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n" - "smlal2 v29.8h, v11.16b, v21.16b\n" - "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" - "smlal2 v30.8h, v12.16b, v21.16b\n" - "add %[weight_ptr_oc], %[weight_ptr_oc], " - "%[weight_step]\n" - "smlal2 v31.8h, v13.16b, v21.16b\n" - "ldp q16, q17, [%[weight_ptr]]\n" - "smull v24.8h, v8.8b, v18.8b\n" - "ldp q10, q11, [%[nchw_src_ptr], #32]\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v9.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v14.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v15.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v28.8h, v8.8b, v19.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "ldp q8, q9, [%[nchw_src_ptr]]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v14.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - 
"smull v31.8h, v15.8b, v19.8b\n" - "ldp q19, q20, [%[weight_ptr_oc]]\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v24.8h, v8.8b, v16.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v9.8b, v16.8b\n" - "ldp q12, q13, [%[nchw_src_ptr], #64]\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v10.8b, v16.8b\n" - "ldr q18, [%[weight_ptr],#32]\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v11.8b, v16.8b\n" - "ldr q21, [%[weight_ptr_oc],#32]\n" - "sadalp %[c13].4s, v31.8h\n" - //! fh = 4 - "smull v28.8h, v8.8b, v19.8b\n" - "smlal2 v24.8h, v8.16b, v16.16b\n" - "smlal2 v25.8h, v9.16b, v16.16b\n" - "smlal2 v26.8h, v10.16b, v16.16b\n" - "smlal2 v27.8h, v11.16b, v16.16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v10.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v11.8b, v19.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v8.16b, v19.16b\n" - "ldr d8, [%[nchw_src_ptr],#48]\n" - "smlal2 v29.8h, v9.16b, v19.16b\n" - "smlal2 v30.8h, v10.16b, v19.16b\n" - "smlal2 v31.8h, v11.16b, v19.16b\n" - "smull v24.8h, v9.8b, v17.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v10.8b, v17.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v11.8b, v17.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v12.8b, v17.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v9.16b, v17.16b\n" - "smlal2 v25.8h, v10.16b, v17.16b\n" - "smlal2 v26.8h, v11.16b, v17.16b\n" - "smlal2 v27.8h, v12.16b, v17.16b\n" - "smull v28.8h, v9.8b, v20.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v10.8b, v20.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v11.8b, v20.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v12.8b, v20.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v9.16b, v20.16b\n" - "ldr d9, [%[nchw_src_ptr],#64]\n" - "smlal2 v29.8h, v10.16b, v20.16b\n" - "ldr d14, [%[nchw_src_ptr],#80]\n" - "smlal2 v30.8h, v11.16b, v20.16b\n" - "smlal2 v31.8h, v12.16b, v20.16b\n" - 
"smull v24.8h, v10.8b, v18.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v11.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v12.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v13.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal2 v24.8h, v10.16b, v18.16b\n" - "ldr d19, [%[weight_ptr_oc],#48]\n" - "smlal2 v25.8h, v11.16b, v18.16b\n" - "ldr d15, [%[nchw_src_ptr],#96]\n" - "smlal2 v26.8h, v12.16b, v18.16b\n" - "smlal2 v27.8h, v13.16b, v18.16b\n" - "ldr d18, [%[weight_ptr],#48]\n" - "smull v28.8h, v10.8b, v21.8b\n" - "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v11.8b, v21.8b\n" - "add %[weight_ptr_oc], %[weight_ptr_oc], %[weight_step]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v12.8b, v21.8b\n" - "ldr q16, [%[weight_ptr]]\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v13.8b, v21.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal2 v28.8h, v10.16b, v21.16b\n" - "smlal2 v29.8h, v11.16b, v21.16b\n" - "ldp q10, q11, [%[nchw_src_ptr_last_line], #32]\n" - "smlal2 v30.8h, v12.16b, v21.16b\n" - "smlal2 v31.8h, v13.16b, v21.16b\n" - "ldp q12, q13, [%[nchw_src_ptr_last_line], #64]\n" - "smull v24.8h, v8.8b, v18.8b\n" - "ldr d21, [%[weight_ptr_oc],#16]\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v25.8h, v9.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v26.8h, v14.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v27.8h, v15.8b, v18.8b\n" - "ldr d18, [%[weight_ptr],#16]\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v28.8h, v8.8b, v19.8b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "ldp q8, q9, [%[nchw_src_ptr_last_line]]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v14.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v15.8b, v19.8b\n" - "ldr q19, [%[weight_ptr_oc]]\n" - "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" - "tbl v9.16b, {v9.16b}, %[vtbl].16b\n" - "sadalp %[c03].4s, v27.8h\n" - "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" - "tbl 
v11.16b, {v11.16b}, %[vtbl].16b\n" - "sadalp %[c10].4s, v28.8h\n" - "tbl v12.16b, {v12.16b}, %[vtbl].16b\n" - "tbl v13.16b, {v13.16b}, %[vtbl].16b\n" - "sadalp %[c11].4s, v29.8h\n" - /// last line//// - "smull v24.8h, v8.8b, v16.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v25.8h, v9.8b, v16.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v26.8h, v10.8b, v16.8b\n" - "smull v27.8h, v11.8b, v16.8b\n" - "smlal2 v24.8h, v9.16b, v16.16b\n" - "smlal2 v25.8h, v10.16b, v16.16b\n" - "smlal2 v26.8h, v11.16b, v16.16b\n" - "smlal2 v27.8h, v12.16b, v16.16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v28.8h, v8.8b, v19.8b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v29.8h, v9.8b, v19.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v30.8h, v10.8b, v19.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v31.8h, v11.8b, v19.8b\n" - "smlal2 v28.8h, v9.16b, v19.16b\n" - "dup v9.8b, v11.b[0]\n" - "smlal2 v29.8h, v10.16b, v19.16b\n" - "smlal2 v30.8h, v11.16b, v19.16b\n" - "smlal2 v31.8h, v12.16b, v19.16b\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v24.8h, v10.8b, v18.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v25.8h, v11.8b, v18.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v26.8h, v12.8b, v18.8b\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v27.8h, v13.8b, v18.8b\n" - "add x10, %[nchw_src_ptr_last_line], #96\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v28.8h, v10.8b, v21.8b\n" - - "sadalp %[c01].4s, v25.8h\n" - "add x5, %[weight_ptr], #24\n" - "smull v29.8h, v11.8b, v21.8b\n" - "add x6, %[weight_ptr_oc], #24\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v30.8h, v12.8b, v21.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v31.8h, v13.8b, v21.8b\n" - "dup v10.8b, v12.b[0]\n" - "sadalp %[c10].4s, v28.8h\n" - "ld1r {v12.8b}, [x10]\n" - "sadalp %[c11].4s, v29.8h\n" - "dup v11.8b, v13.b[0]\n" - "sadalp %[c12].4s, v30.8h\n" - "ld1r {v16.2s}, [x5]\n" - "sadalp %[c13].4s, v31.8h\n" - "sxtl v16.8h, v16.8b\n" - ///////////////last element///////// - "add %[weight_ptr], %[weight_ptr], %[weight_step_small]\n" - 
"sxtl v9.8h, v9.8b\n" - "ld1r {v19.2s}, [x6]\n" - "sxtl v10.8h, v10.8b\n" - "sxtl v11.8h, v11.8b\n" - "smlal %[c00].4s, v9.4h, v16.4h\n" - "sxtl v12.8h, v12.8b\n" - "smlal %[c01].4s, v10.4h, v16.4h\n" - "sxtl v19.8h, v19.8b\n" - "smlal %[c02].4s, v11.4h, v16.4h\n" - "smlal %[c03].4s, v12.4h, v16.4h\n" - "smlal %[c10].4s, v9.4h, v19.4h\n" - "smlal %[c11].4s, v10.4h, v19.4h\n" - "smlal %[c12].4s, v11.4h, v19.4h\n" - "smlal %[c13].4s, v12.4h, v19.4h\n" - : - - [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), - [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), - [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), - [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), - [nchw_src_ptr] "+r"(nchw_src_ptr), - [weight_ptr] "+r"(weight_ptr), - [weight_ptr_oc] "+r"(weight_ptr_oc) - - : [vtbl] "w"(vtbl), - [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), - [src_step] "r"(src_step), [weight_step] "r"(weight_step), - [weight_step_small] "r"(weight_step_small) - : "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -#endif -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 5; - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int ih_step = 2; - constexpr int ic_step = 1; - constexpr int oc_step = 4; - constexpr int pack_iw_len = 4; - constexpr int fh_end = filter_size / ih_step * ih_step; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][4]; - - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx 
< ic; ic_idx += ic_step) { - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - int8x16_t src[5]; - int8x16_t dot4_weight[c_dim][2]; - int16x8_t temp_c[4]; - load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( - c, src_dot2, dot2_weight, temp_c); - weight_ptr += filter_size * pack_iw_len * ih_step; - } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - fh_end * iw * ic_step * pack_iw_len; - - int8x8_t dot2_weight[c_dim][2]; - int16x8_t temp_c[4]; - int8x8_t src_dot2[5]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); - - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); - weight_ptr += filter_size * pack_iw_len; - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -/** - * filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is 
f0 = filter[0, 0, :, :, :] - * calculate sequence \ - * f0[0:1, 0:1, 4] dot4, \ - * f0[0:1, 2, 4] dot2, \ - * f0[2, 0:1, 4] dot2, \ - * f0[2, 2, 4] dot1 \ - * look like: - * |---|-| - * |x x|x| - * |x x|x| - * |-----| - * |x x|x| - * |-----| - **/ -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 3; - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int loop_ic_step = 1; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][4]; - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - // first 2 line - { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[4]; - int8x16_t dot4_weight[c_dim][1]; - int16x8_t temp_c[4]; - load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_weight_oc); - load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( - c, src, dot4_weight, temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( - c, src_dot2, dot2_weight, temp_c); - } - // last line - { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 2 * iw * ic_step * pack_iw_len; - int16x8_t temp_c[4]; - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - 
ld_weight_oc); - load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( - src_dot2, nchw_src_ptr, 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( - c, src_dot2, dot2_weight, temp_c); - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); - weight_ptr += filter_size * filter_size * pack_iw_len; - } - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -#if MEGDNN_AARCH64 -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 3; - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int oc_block = 8; - constexpr int remain_w = 0; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int loop_ic_step = 1; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - const size_t weight_step = filter_size * filter_size * pack_iw_len; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][4]; - init_ocx_ow4(c, bias_ptr, oc_step); - uint8x16_t vtbl = vld1q_u8(src_idx_buffer); - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - const int8_t* nchw_src_ptr_last_line = - src_ptr + ic_idx * ic_stride + - 2 * iw * ic_step * pack_iw_len; - const int8_t* weight_ptr_oc = weight_ptr + ld_weight_oc; - /** - * r0-r7 c - * r24-r31 temp - * r8-r15 src - * r16-r19 weight - * r20-vtbl - */ - asm volatile( - //! 
load src 0,1 - "ldp q8,q9, [%[nchw_src_ptr]]\n" - "ldr q16, [%[weight_ptr]]\n" - "ldp q10,q11, [%[nchw_src_ptr], #32]\n" - "add x5, %[weight_ptr], #32\n" - "smull v24.8h, v8.8b, v16.8b\n" - "ldr q17, [%[weight_ptr_oc]]\n" - "smull v25.8h, v9.8b, v16.8b\n" - "add x6, %[weight_ptr_oc], #32\n" - "smull v26.8h, v10.8b, v16.8b\n" - "smull v27.8h, v11.8b, v16.8b\n" - "smlal2 v24.8h, v8.16b, v16.16b\n" - "add x7, %[nchw_src_ptr_last_line], #64\n" - "smlal2 v25.8h, v9.16b, v16.16b\n" - "smlal2 v26.8h, v10.16b, v16.16b\n" - "smlal2 v27.8h, v11.16b, v16.16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v28.8h, v8.8b, v17.8b\n" - "ldr d12, [%[nchw_src_ptr],#16]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v29.8h, v9.8b, v17.8b\n" - "ldr d13, [%[nchw_src_ptr],#32]\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v30.8h, v10.8b, v17.8b\n" - "ldr d14, [%[nchw_src_ptr],#48]\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v31.8h, v11.8b, v17.8b\n" - "ldr d18, [%[weight_ptr],#16]\n" - "smlal2 v28.8h, v8.16b, v17.16b\n" - "ldr d19, [%[weight_ptr_oc],#16]\n" - "smlal2 v29.8h, v9.16b, v17.16b\n" - "ldr d15, [%[nchw_src_ptr],#64]\n" - "smlal2 v30.8h, v10.16b, v17.16b\n" - "ldp q8,q9, [%[nchw_src_ptr_last_line]]\n" - "smull v24.8h, v12.8b, v18.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smlal2 v31.8h, v11.16b, v17.16b\n" - "ldp q10,q11, [%[nchw_src_ptr_last_line], #32]\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v25.8h, v13.8b, v18.8b\n" - "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v26.8h, v14.8b, v18.8b\n" - "ldr d16, [%[weight_ptr],#24]\n" - "sadalp %[c13].4s, v31.8h\n" - "ldr d17, [%[weight_ptr_oc],#24]\n" - "smull v27.8h, v15.8b, v18.8b\n" - "tbl v9.16b, {v9.16b}, %[vtbl].16b\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v28.8h, v12.8b, v19.8b\n" - "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v29.8h, v13.8b, v19.8b\n" - "tbl v11.16b, {v11.16b}, %[vtbl].16b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v30.8h, v14.8b, v19.8b\n" - "ld1r 
{v18.2s}, [x5]\n" - "sadalp %[c03].4s, v27.8h\n" - "smull v31.8h, v15.8b, v19.8b\n" - "ld1r {v19.2s}, [x6]\n" - "sadalp %[c10].4s, v28.8h\n" - "smull v24.8h, v8.8b, v16.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smull v25.8h, v9.8b, v16.8b\n" - "dup v12.8b, v9.b[0]\n" - "sadalp %[c12].4s, v30.8h\n" - "smull v26.8h, v10.8b, v16.8b\n" - "dup v12.8b, v9.b[0]\n" - "sadalp %[c13].4s, v31.8h\n" - "smull v27.8h, v11.8b, v16.8b\n" - "dup v13.8b, v10.b[0]\n" - "smull v28.8h, v8.8b, v17.8b\n" - "dup v14.8b, v11.b[0]\n" - "sadalp %[c00].4s, v24.8h\n" - "smull v29.8h, v9.8b, v17.8b\n" - "ld1r {v15.8b}, [x7]\n" - "sadalp %[c01].4s, v25.8h\n" - "smull v30.8h, v10.8b, v17.8b\n" - "sxtl v12.8h, v12.8b\n" - "sxtl v18.8h, v18.8b\n" - "sadalp %[c02].4s, v26.8h\n" - "smull v31.8h, v11.8b, v17.8b\n" - "sxtl v13.8h, v13.8b\n" - "sadalp %[c03].4s, v27.8h\n" - "smlal %[c00].4s, v12.4h, v18.4h\n" - "sxtl v14.8h, v14.8b\n" - "sadalp %[c10].4s, v28.8h\n" - "smlal %[c01].4s, v13.4h, v18.4h\n" - "sxtl v15.8h, v15.8b\n" - "sadalp %[c11].4s, v29.8h\n" - "smlal %[c02].4s, v14.4h, v18.4h\n" - "sxtl v19.8h, v19.8b\n" - "sadalp %[c12].4s, v30.8h\n" - "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" - "smlal %[c03].4s, v15.4h, v18.4h\n" - "sadalp %[c13].4s, v31.8h\n" - "smlal %[c10].4s, v12.4h, v19.4h\n" - "smlal %[c11].4s, v13.4h, v19.4h\n" - "smlal %[c12].4s, v14.4h, v19.4h\n" - "smlal %[c13].4s, v15.4h, v19.4h\n" - : - - [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), - [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), - [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), - [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), - - [weight_ptr] "+r"(weight_ptr), - [weight_ptr_oc] "+r"(weight_ptr_oc) - : [vtbl] "w"(vtbl), [nchw_src_ptr] "r"(nchw_src_ptr), - [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), - [weight_step] "r"(weight_step) - : "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v16", "v17", "v18", "v19", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); - } - 
store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -#endif - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 2; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int pack_iw_len = 4; +namespace int8_direct_nchw_nchw44 { - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][4]; - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[4]; - int8x16_t dot4_weight[c_dim][1]; - int16x8_t temp_c[4]; - load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_weight_oc); - load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - weight_ptr += oc_step * filter_size * filter_size; - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 2; - constexpr int filter_width = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 1; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += 
loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; - load_helper( - dot4_weight, weight_ptr, ld_weight_oc); - load_helper( - src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - - load_helper( - dot4_weight, weight_ptr + 1 * filter_width * oc_step, - ld_weight_oc); - load_helper( - src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - - weight_ptr += oc_step * filter_height * filter_width; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 3; - constexpr int filter_width = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 1; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; - load_helper( - dot4_weight, weight_ptr, ld_weight_oc); - - load_helper( - src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - load_helper( - dot4_weight, weight_ptr + 1 * filter_width * oc_step, - ld_weight_oc); - - load_helper( - src, 
nchw_src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - - load_helper( - dot4_weight, weight_ptr + 2 * filter_width * oc_step, - ld_weight_oc); - load_helper( - src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); - - weight_ptr += oc_step * filter_height * filter_width; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 5; - constexpr int filter_width = 8; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 2; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, \ - ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, \ - nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ - 0); \ - cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); - UNROLL_CALL_RAW(5, cb); -#undef cb - weight_ptr += oc_step * filter_height * filter_width; - } - 
store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 7; - constexpr int filter_width = 8; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 2; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, \ - ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, \ - nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ - 0); \ - cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); - - UNROLL_CALL_RAW(7, cb); -#undef cb - weight_ptr += oc_step * filter_height * filter_width; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; -template -MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, - int left_pad, int right_pad, - const int iw) { - const int8_t* src_row_0 = inptr; - const int8_t* src_row_1 = inptr + iw; - constexpr int combine_row = 2; - constexpr int iw_step = 16; - 
constexpr int src_expand = 4; - constexpr int out_gap = iw_step * src_expand; - const int iw_end = iw / iw_step * iw_step; - - memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); - outptr += combine_row * left_pad * src_expand; - - for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { - int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); - int8x16_t row1 = vdupq_n_s8(0); - if (mode == PACK_MODE::NO_PAD) { - row1 = vld1q_s8(src_row_1 + iw_idx); - } else if (mode == PACK_MODE::FIRST_PAD) { - row1 = row0; - row0 = vdupq_n_s8(0); - } - int8x16x2_t pack_rows = vzipq_s8(row0, row1); -#define STORE_8S8(step) \ - vst1_s8(outptr + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[0]), step))); - - UNROLL_CALL_RAW(8, STORE_8S8); -#undef STORE_8S8 -#define STORE_8S8(step) \ - vst1_s8(outptr + out_gap + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[1]), step))); - - UNROLL_CALL_RAW(8, STORE_8S8); -#undef STORE_8S8 - outptr += out_gap * combine_row; - } - for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { - int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); - int8x8_t row1 = vdup_n_s8(0); - if (mode == PACK_MODE::NO_PAD) { - row1 = vld1_dup_s8(src_row_1 + iw_idx); - } else if (mode == PACK_MODE::FIRST_PAD) { - row1 = row0; - row0 = vdup_n_s8(0); - } - int8x8x2_t pack_rows = vzip_s8(row0, row1); - vst1_s8(outptr, pack_rows.val[0]); - outptr += src_expand * combine_row; - } - memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); - outptr += combine_row * right_pad * src_expand; -} template void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, @@ -1391,509 +30,19 @@ void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, const int right_pad, const int ih, const int iw, const int iw2, const int pw, int8_t* temp_ptr); -/** - * pack (ic, h, w) to (ic, h / 2, 2 * w) - * pack interleave two adjacent 
row in src and repeat 4 times, store to one row - * */ -template <> -void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, - const int ic, const int top_pad, - const int bottom_pad, const int left_pad, - const int right_pad, const int ih, - const int iw, const int, const int, - int8_t*) { - constexpr int src_expand = 4; - constexpr int oh_step = 2; - const int oh = ih + top_pad + bottom_pad; - const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; - const int ow = (iw + left_pad + right_pad) * src_expand; - - for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { - int oh_idx = 0; - for (; oh_idx < top_pad; oh_idx += oh_step) { - if (top_pad - oh_idx >= oh_step) { - memset(outptr, 0, oh_step * ow * sizeof(int8_t)); - } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += iw; - } - outptr += oh_step * ow; - } - - for (; oh_idx < oh_end; oh_idx += oh_step) { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += oh_step * iw; - outptr += oh_step * ow; - } - - for (; oh_idx < oh; oh_idx += oh_step) { - const int last_pad = oh_idx - ih - top_pad; - if (last_pad >= 0) { - memset(outptr, 0, oh_step * ow * sizeof(int8_t)); - } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += iw; - } - outptr += oh_step * ow; - } - } -} -/** - * pack (ic, h, w) to (ic, h, w * 16) - * pack interleave two adjacent row in src and repeat 4 times, store to one row - * */ -template <> -void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, - int8_t* sptr_base, const int ic, - const int pad_top, const int pad_bottom, - const int, const int, const int ih, - const int iw, const int iw2, const int pw, - int8_t* temp_ptr) { - static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3}; - uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); - - constexpr int iw_step = 4; - constexpr int pack_iw_len = 16; - const int ic_stride = ih * iw; - const int iw_with_pad = iw + 2 * pw; 
- const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; - rep(ic_idx, ic) { - const int8_t* sptr = sptr_origin + ic_idx * ic_stride; - memset(sptr_base, 0, - sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * - pack_iw_len); - sptr_base += iw2 * pad_top * pack_iw_len; - rep(ih_idx, ih) { - memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); - memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); - for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { - int8x16_t src[4]; - int8x16_t dst[4]; - src[0] = vld1q_s8(temp_ptr + iw_idx); - src[1] = vld1q_s8(temp_ptr + iw_idx + 1); - src[2] = vld1q_s8(temp_ptr + iw_idx + 2); - src[3] = vld1q_s8(temp_ptr + iw_idx + 3); - dst[0] = vqtbl1q_s8(src[0], tbl_idx); - dst[1] = vqtbl1q_s8(src[1], tbl_idx); - dst[2] = vqtbl1q_s8(src[2], tbl_idx); - dst[3] = vqtbl1q_s8(src[3], tbl_idx); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); - } - for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { - int8x16_t src = vld1q_s8(temp_ptr + iw_idx); - int8x16_t dst = vqtbl1q_s8(src, tbl_idx); - vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst); - } - sptr_base += iw2 * pack_iw_len; - sptr += iw; - } - sptr_base += iw2 * pad_bottom * pack_iw_len; - } -} template void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, const int ic, const int fh, const int fw, const int oc); -/** - * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} - * pack interleave two adjacent row in filter to one row - * */ -template <> -void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, - const int ic, const int fh, - const int fw, const int oc) { - constexpr int oc_step = 4; - constexpr int ic_step = 2; - constexpr int fh_step = 2; - constexpr int fw_step = 2; - const int ic_end = ic / ic_step * 
ic_step; - const int ic_remain = ic - ic_end; - const int fh_end = fh / fh_step * fh_step; - const int fh_remain = fh - fh_end; - const int fw_end = fw / fw_step * fw_step; - const int fw_remain = fw - fw_end; - const int filter_stride = ic * oc_step; - static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, - 4, 12, 5, 13, 6, 14, 7, 15}; - uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); - for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { - const int ic_offset = ic_idx * oc_step; - int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; - int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int fh_offset = fh_idx * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vld1_s8(filter_ptr); - int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - vst1_s8(output_ic1, vget_high_s8(combine_row)); - output_ic0 += 8; - output_ic1 += 8; - } - } - if (fh_remain > 0) { - const int fh_offset = fh_end * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vld1_s8(filter_ptr); - int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - vst1_s8(output_ic1, vget_high_s8(combine_row)); - output_ic0 += 8; - output_ic1 += 8; - } - if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; - int8x8_t row_0 = 
vld1_s8(filter_ptr); - vst1_lane_s32((int32_t*)output_ic0, - vreinterpret_s32_s8(row_0), 0); - vst1_lane_s32((int32_t*)output_ic1, - vreinterpret_s32_s8(row_0), 1); - output_ic0 += 4; - output_ic1 += 4; - } - } - } - if (ic_remain > 0) { - const int ic_offset = ic_end * oc_step; - int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int fh_offset = fh_idx * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vreinterpret_s8_s32( - vld1_dup_s32((const int32_t*)(filter_ptr))); - int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( - (const int32_t*)(filter_ptr + fw * filter_stride))); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - output_ic0 += 8; - } - } - if (fh_remain > 0) { - const int fh_offset = fh_end * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vreinterpret_s8_s32( - vld1_dup_s32((const int32_t*)(filter_ptr))); - int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( - (const int32_t*)(filter_ptr + filter_stride))); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - output_ic0 += 8; - } - if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; - *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); - output_ic0 += 4; - } - } - } - inptr += oc_step * fh * fw * ic; - outptr += oc_step * fh * fw * ic; - } -} -/** - * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} - * pack interleave two adjacent row in filter to one row - * */ -template <> 
-void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, - const int ic, const int fh, - const int fw, const int oc) { - constexpr int oc_step = 4; - const int fw2 = round_up(fw, 4); - const int fw_remain = fw2 - fw; - const int dst_ic_stride = fh * fw2; - const int oc_step_stride = fh * fw2 * ic * oc_step; - static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7, - 8, 12, 9, 13, 10, 14, 11, 15}; - uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); - rep_step(oc_idx, oc, oc_step) { - int32_t* dst_temp_ptr = - reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); - const int32_t* src_temp_ptr = reinterpret_cast( - src_ptr + oc_idx * ic * fh * fw); - // transpose ic and pad - rep(fh_idx, fh) { - rep(fw_idx, fw) { - rep(ic_idx, ic) { - *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; - src_temp_ptr++; - } - dst_temp_ptr++; - } - rep(ic_idx, ic) { - memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, - sizeof(int8_t) * oc_step * fw_remain); - } - dst_temp_ptr += fw_remain; - } - // transpose fw oc - int8_t* trans_dst_temp_ptr = - reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); - rep_step(idx, oc_step_stride, 16) { - int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); - vst1q_s8(trans_dst_temp_ptr + idx, - vqtbl1q_s8(temp, tbl_transpose_4x4)); - } - } -}; template struct ConvDiectStrideInt8NchwNchw44 { static void impl(const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t fh = filter_size; - constexpr size_t fw = - stride == 2 ? filter_size : (filter_size + 3) / 4 * 4; - constexpr size_t ic_step = 1; - constexpr size_t big_oc_step = 8; - constexpr size_t oc_step = 4; - constexpr size_t ih_step = stride == 2 ? 2 : 1; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = stride == 2 ? 
4 : 8; - constexpr size_t stride_h = stride; - constexpr size_t stride_w = stride; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - break; - - UNROLL_CALL_RAW(4, cb); - default: - megdnn_assert(0, "no remain %zu for kern", ow_remain); - } - - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - } - } - if (oc_remain > 0) { - size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t 
src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, - filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); - } - } - } - } -}; - -template -struct ConvDiectStrideInt8NchwNchw44 { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, int8_t* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr int stride = 1; - constexpr size_t fh = filter_size; - constexpr size_t fw = (filter_size + 3) / 4 * 4; - constexpr size_t ic_step = 1; - constexpr size_t big_oc_step = 8; - constexpr size_t oc_step = 4; - constexpr size_t ih_step = 1; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = stride; - constexpr size_t stride_w = stride; - constexpr int pack_iw_len = 16; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - kern_small_oc_remain = \ - 
KerNeonXXs2NchwNchw44::impl; \ - break; - - UNROLL_CALL_RAW(8, cb); - default: - megdnn_assert(0, "no remain %zu for kern", ow_remain); - } - - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - } - } - - if (oc_remain > 0) { - size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, - filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, 
iw, - ld_dst_oc, op); - } - } - } - } + const Op& op); }; template @@ -1908,7 +57,7 @@ static void conv_direct_int8_nchw_nchw44(const int8_t* src, src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } -} // namespace +} // namespace int8_direct_nchw_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 798dc967..25b6f955 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -93,8 +93,8 @@ void do_weight_trans(const WorkspaceBundle& bundle, const int fw2 = round_up(fw, 4); auto packed_weight = reinterpret_cast(bundle.get(1)); auto origin_weight = kern_param.filter(); - pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh, - fw, fw2); + dot_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44_dot( + packed_weight, origin_weight, oc, ic, fh, fw, fw2); } template @@ -147,7 +147,7 @@ static void do_conv_kern(const WorkspaceBundle& bundle, tmp_ptr = reinterpret_cast(bundle.get(2)) + ncb_index.thread_id * tmp_size; } - pack_src_int8_nchw_nchw44_dot( + dot_direct_nchw_nchw44::pack_src_int8_nchw_nchw44_dot( sptr, origin_sptr, ph, pw, remain_right_pad, ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, src_bottom_pad, ic, ih * iw, tmp_ptr); @@ -164,7 +164,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; Op op(scale_bias, scale_dst); - conv_direct_int8_nchw_nchw44_dot( + dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot( sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, oh_block_real, ow, op); } diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h index 
6bc57ebb..3529d907 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h @@ -20,83 +20,15 @@ #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -using namespace megdnn; -using namespace arm_common; -namespace { +namespace megdnn { +namespace arm_common { +namespace dot_direct_nchw_nchw44 { template struct ShiftCalHelper { static void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2], weight[0][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[1][step * 2], weight[1][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2 + 1], weight[0][weight_idx], \ - src[1][(src_idx + step) / 4]); \ - c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[1][step * 2 + 1], weight[1][weight_idx], \ - src[1][(src_idx + step) / 4]); - - UNROLL_CALL_RAW(4, cb); -#undef cb - } -}; - -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2], weight[0][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2 + 1], weight[0][weight_idx], \ - src[1][(src_idx + step) / 4]); - - UNROLL_CALL_RAW(4, cb); -#undef cb - } -}; - -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \ - c[1][step] = Func::template impl<(src_idx + step) % 4>( \ - c[1][step], weight[1][weight_idx], 
src[(src_idx + step) / 4]); - - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; - -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); - - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; - template inline void cal_helper(T& c, T2& src, T3& weight) { @@ -133,490 +65,12 @@ struct KerNeonDotXXs2Nchw44Int8 { int iw, int ld_dst_oc, const Op& op); }; -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_hight = 2; - constexpr int filter_width = 4; - constexpr int weight_reg = 1; - constexpr int src_reg = 1; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 1; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[2][src_reg]; - int8x16_t weight[c_dim][weight_reg]; - // row 0 - load_helper( - src, src_ptr + 0 * iw, stride); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 1 - load_helper( - src, src_ptr + 1 * iw, stride); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - - src_ptr += ic_stride; - weight_ptr += filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const 
int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_hight = 3; - constexpr int filter_width = 4; - constexpr int weight_reg = 1; - constexpr int src_reg = 1; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 1; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[2][src_reg]; - int8x16_t weight[c_dim][weight_reg]; - // row 0 - load_helper( - src, src_ptr + 0 * iw, stride); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 1 - load_helper( - src, src_ptr + 1 * iw, stride); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 2 - load_helper( - src, src_ptr + 2 * iw, stride); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - - src_ptr += ic_stride; - weight_ptr += filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_hight = 5; - constexpr int filter_width = 8; - constexpr int src_reg = 2; - constexpr int weight_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 1; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const 
int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[2][src_reg]; - int8x16_t weight[c_dim][weight_reg]; -#define cb(step) \ - load_helper(src, src_ptr + step * iw, \ - stride); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ - cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); - UNROLL_CALL_RAW(5, cb); -#undef cb - src_ptr += ic_stride; - weight_ptr += 5 * 32; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -/** - * oc = 8, ow = 8 - * dot 4 element, pad last filter and do twice dot every row filter, filter like - * below - * -------------------------- - * |x, x, x, x,| x, x, x, 0 | - * -------------------------- - **/ -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_hight = 7; - constexpr int filter_width = 8; - constexpr int src_reg = 2; - constexpr int weight_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 1; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[2][src_reg]; - int8x16_t weight[c_dim][weight_reg]; -#define cb(step) \ - load_helper(src, src_ptr + step * iw, \ - stride); \ - load_helper( \ - weight, weight_ptr + 
step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ - cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); - UNROLL_CALL_RAW(7, cb); -#undef cb - src_ptr += ic_stride; - weight_ptr += 7 * 32; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; -////////////////////stride 1/////////////////// -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_hight = 2; - constexpr int filter_width = 4; - constexpr int weight_reg = 2; - constexpr int src_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 4; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[src_reg]; - int8x16_t weight[c_dim][weight_reg]; - // row 0 - load_helper( - src, src_ptr + 0 * iw * pack_iw_len, 0); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 1 - load_helper( - src, src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - - src_ptr += ic_stride; - weight_ptr += filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& 
op) { - constexpr int stride = 1; - constexpr int filter_hight = 3; - constexpr int filter_width = 4; - constexpr int weight_reg = 3; - constexpr int src_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 4; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[src_reg]; - int8x16_t weight[c_dim][weight_reg]; - // row 0 - load_helper( - src, src_ptr + 0 * iw * pack_iw_len, 0); - load_helper( - weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 1 - load_helper( - src, src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - // row 2 - load_helper( - src, src_ptr + 2 * iw * pack_iw_len, 0); - cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); - - src_ptr += ic_stride; - weight_ptr += filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_hight = 5; - constexpr int filter_width = 8; - constexpr int src_reg = 3; - constexpr int weight_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 4; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = 
OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[src_reg]; - int8x16_t weight[c_dim][weight_reg]; -#define cb(step) \ - load_helper( \ - src, src_ptr + step * iw * pack_iw_len, 0); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ - cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); - - UNROLL_CALL_RAW(5, cb); -#undef cb - src_ptr += ic_stride; - weight_ptr += filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_hight = 7; - constexpr int filter_width = 8; - constexpr int src_reg = 3; - constexpr int weight_reg = 2; - - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int pack_iw_len = 4; - constexpr int simd_len = 16; - - const int ld_bias = oc_step; - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - int8x16_t src[src_reg]; - int8x16_t weight[c_dim][weight_reg]; -#define cb(step) \ - load_helper( \ - src, src_ptr + step * iw * pack_iw_len, 0); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ - cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); - - UNROLL_CALL_RAW(7, cb); -#undef cb - src_ptr += ic_stride; - weight_ptr 
+= filter_hight * filter_width * oc_step; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - template void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw, const int, const int ih, const int iw, const int iw2, const int pad_top, const int pad_bottom, - const int ic, const int ic_stride, int8_t*) { - constexpr int ic_step = 1; - rep_step(ic_idx, ic, ic_step) { - const int8_t* sptr = sptr_origin + ic_idx * ic_stride; - memset(sptr_base, 0, - sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); - sptr_base += iw2 * pad_top * ic_step; - rep(ih_idx, ih) { - memcpy(sptr_base + pw * ic_step, sptr, - sizeof(int8_t) * iw * ic_step); - sptr_base += iw2 * ic_step; - sptr += iw * ic_step; - } - sptr_base += iw2 * pad_bottom * ic_step; - } -} - -template <> -void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, - const int8_t* sptr_origin, const int, - const int pw, const int, const int ih, - const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride, - int8_t* temp_ptr) { - static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, - 2, 3, 4, 5, 3, 4, 5, 6}; - uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); - - constexpr int iw_step = 16; - constexpr int pack_iw_len = 4; - const int iw_with_pad = iw + 2 * pw; - const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; - rep(ic_idx, ic) { - const int8_t* sptr = sptr_origin + ic_idx * ic_stride; - memset(sptr_base, 0, - sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * - pack_iw_len); - sptr_base += iw2 * pad_top * pack_iw_len; - rep(ih_idx, ih) { - memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); - memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); - for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { - int8x16_t src[4]; - int8x16_t dst[4]; - src[0] = vld1q_s8(temp_ptr + iw_idx); - src[1] = vld1q_s8(temp_ptr + iw_idx + 4); - src[2] = vld1q_s8(temp_ptr + iw_idx + 8); - 
src[3] = vld1q_s8(temp_ptr + iw_idx + 12); - dst[0] = vqtbl1q_s8(src[0], tbl_idx); - dst[1] = vqtbl1q_s8(src[1], tbl_idx); - dst[2] = vqtbl1q_s8(src[2], tbl_idx); - dst[3] = vqtbl1q_s8(src[3], tbl_idx); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); - vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); - } - for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { - *(sptr_base + iw_idx * pack_iw_len + 0) = - *(temp_ptr + iw_idx + 0); - *(sptr_base + iw_idx * pack_iw_len + 1) = - *(temp_ptr + iw_idx + 1); - *(sptr_base + iw_idx * pack_iw_len + 2) = - *(temp_ptr + iw_idx + 2); - *(sptr_base + iw_idx * pack_iw_len + 3) = - *(temp_ptr + iw_idx + 3); - } - sptr_base += iw2 * pack_iw_len; - sptr += iw; - } - sptr_base += iw2 * pad_bottom * pack_iw_len; - } -} + const int ic, const int ic_stride, int8_t*); static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, const int8_t* src_ptr, @@ -663,117 +117,15 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, } template -static void conv_direct_int8_nchw_nchw44_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, - const int iw, const int oh, const int oh_block, const int ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr int fh = filter_size; - constexpr int fw = (filter_size + 3) / 4 * 4; -#if MEGDNN_AARCH64 - constexpr int big_oc_step = 8; -#else - constexpr int big_oc_step = 4; -#endif - constexpr int oc_step = 4; - constexpr int ih_step = 1; - constexpr int oh_step = 1; - constexpr int ow_step = 8; - constexpr int stride_h = stride; - constexpr int stride_w = stride; - constexpr int pack_iw_len = stride == 2 ? 
1 : 4; - - const int img_stride = oh * ow; - const int ow_end = ow / ow_step * ow_step; - const int ow_remain = ow - ow_end; - const int oc_end = oc / big_oc_step * big_oc_step; - const int oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = - std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ - kern_small_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ - break; - - UNROLL_CALL_RAW(8, cb); - default: - megdnn_assert(0, "no remain %d for kern", ow_remain); - } - - for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const int weight_offset = oc_idx * ic * fh * fw; - for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { - for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const int src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); - } - } - } - if (oc_remain > 0) { - int oc_idx = oc_end; - const int weight_offset = oc_idx * ic * fh * fw; - for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { - for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; 
- KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const int src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - pack_iw_len; - const int dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - } - } -} - -} // namespace +void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, + const int ih, const int iw, const int oh, + const int oh_block, const int ow, + const Op& op); + +} // namespace dot_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn #endif // vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 7f5f9907..1d1476df 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -2344,7 +2344,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { #endif std::vector gemv_args; for (auto&& arg : args) - if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { + if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { gemv_args.emplace_back(arg); } check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); @@ -2361,7 +2361,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { #endif std::vector gemv_args; for (auto&& arg : args) - if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { + if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { gemv_args.emplace_back(arg); } check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index faf4d816..73db7dcd 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ 
b/dnn/test/arm_common/matrix_mul.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "test/arm_common/fixture.h" @@ -30,8 +31,7 @@ TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) { TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) { matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127), - dtype::Quantized8Asymm(1.3f, (uint8_t)129), - {}, + dtype::Quantized8Asymm(1.3f, (uint8_t)129), {}, handle()); } @@ -232,8 +232,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { Checker checker(handle()); using Param = MatrixMul::Param; - checker.set_before_exec_callback( - AlgoChecker("ARM_COMMON_GEVM")); + checker.set_before_exec_callback(AlgoChecker("ARM_COMMON_GEVM")); std::unique_ptr rng = std::make_unique(-127, 127); checker.set_rng(0, rng.get()).set_rng(1, rng.get()); @@ -251,7 +250,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { .set_dtype(2, dtype::QuantizedS32(6.25f)) .execs({A, B, {}}); }; - + // M = 1 for (size_t N : {1, 10, 16, 33, 64}) for (size_t K : {7, 512, 1024}) @@ -263,8 +262,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) { Checker checker(handle()); using Param = MatrixMul::Param; - checker.set_before_exec_callback( - AlgoChecker("ARM_COMMON_GEVM")); + checker.set_before_exec_callback(AlgoChecker("ARM_COMMON_GEVM")); checker.set_epsilon(1e-2); auto run = [&](size_t M, size_t K, size_t N) { @@ -276,7 +274,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) { B = TensorShape{N, K}; checker.set_param(param).execs({A, B, {}}); }; - + // M = 1 for (size_t M : {1}) for (size_t K : {1000, 4096, 25088}) @@ -298,15 +296,15 @@ TEST_F(ARM_COMMON, FP32_GEMV_MK4) { param.transposeA = false; param.transposeB = false; TensorShape A, B; - A = TensorShape{M/4, K/4, 4, 4}; - B = TensorShape{K/4, 1, 4}; + A = TensorShape{M / 4, K / 4, 
4, 4}; + B = TensorShape{K / 4, 1, 4}; checker.set_param(param).execs({A, B, {}}); }; - + // N = 1 for (size_t M : {4, 16, 128, 1024}) for (size_t K : {4, 8, 12, 128, 256, 4096}) - run(M, K); + run(M, K); } #if MEGDNN_WITH_BENCHMARK @@ -343,7 +341,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { for (size_t M : {4, 64, 1024, 4096}) for (size_t K : {128, 256, 1024, 4096}) - run(M, K, 1); + run(M, K, 1); } TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { @@ -372,7 +370,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { .exec({{2, 1024}, {1024, 512}, {}}); benchmarker.set_display(true); } - + // run gemv run(12, 48, 1); run(48, 12, 1); @@ -396,14 +394,14 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { Benchmarker benchmarker(handle()); benchmarker.set_times(exec_times); benchmarker.set_dtype(0, dtype::Float32()) - .set_dtype(1, dtype::Float32()) - .set_param(param); + .set_dtype(1, dtype::Float32()) + .set_param(param); auto run = [&](size_t M, size_t K) { - printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N); + printf("SGEMV_MK4: (%zu, %zu)\n", M, K); TensorShape A, B; - A = TensorShape{M/4, K/4, 4, 4}; - B = TensorShape{K/4, 1, 4}; + A = TensorShape{M / 4, K / 4, 4, 4}; + B = TensorShape{K / 4, 1, 4}; auto time = benchmarker.exec({A, B, {}}) / exec_times; auto computations = 2.f * M * K * 1e-6; auto perf = computations / time; @@ -422,7 +420,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { // run gemv mk4 for (size_t M : {4, 64, 1024, 4096}) for (size_t K : {128, 1024, 4096}) - run(M, K); + run(M, K); } TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { @@ -490,7 +488,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { //////////////////////// gemv ////////////////////////// for (size_t M : {8, 64, 112, 256}) { for (size_t K : {8, 64, 112, 256}) { - run (M, 1, K); + run(M, 1, K); } } @@ -502,10 +500,8 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { } } } - } - TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { constexpr size_t RUNS = 50; param::MatrixMul param; @@ -514,7 +510,8 @@ TEST_F(ARM_COMMON, 
BENCHMARK_MATRIX_MUL_INT8x8x32) { .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) - .set_param(param).set_display(false); + .set_param(param) + .set_display(false); Benchmarker benchmarker_float(handle()); benchmarker_float.set_display(false).set_times(RUNS); @@ -533,7 +530,7 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { //////////////////////// gemv ////////////////////////// for (size_t M : {8, 64, 112, 256}) { for (size_t K : {8, 64, 112, 256}) { - run (M, 1, K); + run(M, 1, K); } } @@ -618,5 +615,4 @@ TEST_F(ARM_COMMON, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUINT8) { #endif - // vim: syntax=cpp.doxygen