From 32717b0ca404b2385de83d8537ea99e8cb9bc4ff Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 7 Jan 2022 17:32:44 +0800 Subject: [PATCH] fix(build): split some cpp, which consume too much memory when build make build possible at 8G ddr env, when -j8 GitOrigin-RevId: d0c442b41d6633ef9c6304ecad0b3bddae9908e3 --- ...1.cpp => f32_direct_nchw44_kern_2x2s1_bias.cpp} | 5 +- ...ct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_2x2s1_no_bias.cpp | 15 + ...2.cpp => f32_direct_nchw44_kern_2x2s2_bias.cpp} | 5 +- ...ct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_2x2s2_no_bias.cpp | 15 + ...1.cpp => f32_direct_nchw44_kern_3x3s1_bias.cpp} | 5 +- ...ct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_3x3s1_no_bias.cpp | 15 + ...2.cpp => f32_direct_nchw44_kern_3x3s2_bias.cpp} | 5 +- ...ct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_3x3s2_no_bias.cpp | 15 + ...1.cpp => f32_direct_nchw44_kern_5x5s1_bias.cpp} | 5 +- ...ct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_5x5s1_no_bias.cpp | 15 + ...2.cpp => f32_direct_nchw44_kern_5x5s2_bias.cpp} | 5 +- ...ct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_5x5s2_no_bias.cpp | 15 + ...1.cpp => f32_direct_nchw44_kern_7x7s1_bias.cpp} | 5 +- ...ct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_7x7s1_no_bias.cpp | 15 + ...2.cpp => f32_direct_nchw44_kern_7x7s2_bias.cpp} | 5 +- ...ct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw44_kern_7x7s2_no_bias.cpp | 15 + .../f32_direct_nchw44_kern_common_s1.h | 13 +- .../f32_direct_nchw44_kern_common_s2.h | 13 +- ... => f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp} | 5 +- ...hw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | 15 + ... 
=> f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp} | 5 +- ...hw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | 15 + ... => f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp} | 5 +- ...hw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp | 15 + ... => f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp} | 5 +- ...hw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | 15 + ...hw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | 15 + ...hw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | 15 + ...hw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | 15 + ...hw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | 15 + .../f32_direct_nchw_nchw44_kern_common.h | 10 +- ...direct_nchw44_s1.cpp => dot_direct_nchw44_s1.h} | 21 +- .../direct_kernels/dot_direct_nchw44_s1_2x2.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s1_3x3.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s1_5x5.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s1_7x7.cpp | 21 + ...direct_nchw44_s2.cpp => dot_direct_nchw44_s2.h} | 21 +- .../direct_kernels/dot_direct_nchw44_s2_2x2.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s2_3x3.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s2_5x5.cpp | 21 + .../direct_kernels/dot_direct_nchw44_s2_7x7.cpp | 21 + .../int8_direct_nchw_nchw44_common.h | 4 +- .../direct_kernels/int8_direct_nchw_nchw44_s1.cpp | 450 +------------------ .../direct_kernels/int8_direct_nchw_nchw44_s1.h | 481 +++++++++++++++++++++ 
.../int8_direct_nchw_nchw44_s1_1x1.cpp | 19 + .../int8_direct_nchw_nchw44_s1_2x2.cpp | 19 + .../int8_direct_nchw_nchw44_s1_3x3.cpp | 19 + .../int8_direct_nchw_nchw44_s1_5x5.cpp | 19 + .../int8_direct_nchw_nchw44_s1_7x7.cpp | 19 + ....cpp => direct_nchw_nchw44_kern_impl_2x2s1.cpp} | 7 +- ....cpp => direct_nchw_nchw44_kern_impl_2x2s2.cpp} | 7 +- .../kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp | 16 + .../kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp | 16 + .../kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp} | 6 +- .../kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp} | 6 +- .../kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp} | 6 +- .../kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp} | 6 +- dnn/src/fallback/elemwise/opr_binary_impl.cpp | 275 ++++++++++++ dnn/src/fallback/elemwise/opr_impl.cpp | 252 ----------- dnn/src/fallback/elemwise/opr_unary_impl.cpp | 122 ++++++ dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp | 138 ++++++ dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp | 115 +++++ .../{opr_impl.cpp => opr_impl_3.cpp} | 214 +-------- 83 files changed, 2058 insertions(+), 1004 deletions(-) rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_2x2s1.cpp => f32_direct_nchw44_kern_2x2s1_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_5x5s2.cpp => f32_direct_nchw44_kern_2x2s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_5x5s1.cpp => f32_direct_nchw44_kern_3x3s1_bias.cpp} (86%) create mode 
100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_2x2s2.cpp => f32_direct_nchw44_kern_3x3s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_3x3s1.cpp => f32_direct_nchw44_kern_5x5s1_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_7x7s2.cpp => f32_direct_nchw44_kern_5x5s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_7x7s1.cpp => f32_direct_nchw44_kern_7x7s1_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw44_kern_3x3s2.cpp => f32_direct_nchw44_kern_7x7s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp create mode 100644 
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw_nchw44_kern_2x2s1.cpp => f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw_nchw44_kern_2x2s2.cpp => f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw_nchw44_kern_3x3s1.cpp => f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp rename dnn/src/arm_common/conv_bias/fp32/direct_kernels/{f32_direct_nchw_nchw44_kern_3x3s2.cpp => f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp} (86%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp create mode 100644 
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp rename dnn/src/arm_common/conv_bias/int8/direct_kernels/{dot_direct_nchw44_s1.cpp => dot_direct_nchw44_s1.h} (97%) create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp rename dnn/src/arm_common/conv_bias/int8/direct_kernels/{dot_direct_nchw44_s2.cpp => dot_direct_nchw44_s2.h} (97%) create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp create mode 100644 
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp rename dnn/src/arm_common/conv_bias/int8x8x16/kernel/{direct_nchw_nchw44_kern_impl_s1.cpp => direct_nchw_nchw44_kern_impl_2x2s1.cpp} (83%) rename dnn/src/arm_common/conv_bias/int8x8x16/kernel/{direct_nchw_nchw44_kern_impl_s2.cpp => direct_nchw_nchw44_kern_impl_2x2s2.cpp} (83%) create mode 100644 dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp rename dnn/src/arm_common/conv_bias/{fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp => int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp} (66%) rename dnn/src/arm_common/conv_bias/{fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp => int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp} (66%) rename dnn/src/arm_common/conv_bias/{fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp => int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp} (66%) rename dnn/src/arm_common/conv_bias/{fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp => int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp} (66%) create mode 100644 dnn/src/fallback/elemwise/opr_binary_impl.cpp create mode 
100644 dnn/src/fallback/elemwise/opr_unary_impl.cpp create mode 100644 dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp create mode 100644 dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp rename dnn/src/naive/elemwise_multi_type/{opr_impl.cpp => opr_impl_3.cpp} (56%) diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp index 30874d33..200769d6 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" -INSTANTIATION_CONV_S1(2); \ No newline at end of file +INSTANTIATION_CONV_S1_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..c6c974a5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp new file mode 100644 index 00000000..6f075a54 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_NO_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp index be37fc1c..a9728847 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" -INSTANTIATION_CONV_S2(5); \ No newline at end of file +INSTANTIATION_CONV_S2_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..ae899e2c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp new file mode 100644 index 00000000..94c09aea --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_NO_BIAS(2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp index 689df883..0047c51e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" -INSTANTIATION_CONV_S1(5); \ No newline at end of file +INSTANTIATION_CONV_S1_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..c273dede --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp new file mode 100644 index 00000000..719dbd1d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_NO_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp index 355c9ed6..01209f9c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" -INSTANTIATION_CONV_S2(2); \ No newline at end of file +INSTANTIATION_CONV_S2_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..7bed53e2 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp new file mode 100644 index 00000000..9aa190df --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_NO_BIAS(3); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp index 6d99b01b..5cbcb78a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" -INSTANTIATION_CONV_S1(3); \ No newline at end of file +INSTANTIATION_CONV_S1_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..bcf92bab --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp new file mode 100644 index 00000000..d944b02b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_NO_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp index a7571939..a75f159a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" -INSTANTIATION_CONV_S2(7); +INSTANTIATION_CONV_S2_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..ff9653ea --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp new file mode 100644 index 00000000..a2705bde --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_NO_BIAS(5); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp index db2be268..47cbf3d7 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" -INSTANTIATION_CONV_S1(7); \ No newline at end of file +INSTANTIATION_CONV_S1_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..f8fa2c29 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp new file mode 100644 index 00000000..c7824aad --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +INSTANTIATION_CONV_S1_NO_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp index 6075a8ca..fd603f3b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" -INSTANTIATION_CONV_S2(3); \ No newline at end of file +INSTANTIATION_CONV_S2_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..bd1e5c29 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp new file mode 100644 index 00000000..0caee33c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +INSTANTIATION_CONV_S2_NO_BIAS(7); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index bf9418b4..3915caea 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -469,9 +469,12 @@ void conv_bias::conv_direct_fp32_nchw44( INSTANTIATION(filter_size, bias, HSwishOp) \ INSTANTIATION(filter_size, bias, SigmoidOp) -#define INSTANTIATION_CONV_S1(filter_size) \ - FOR_OP(filter_size, BiasMode::NO_BIAS) \ - FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ - FOR_OP(filter_size, BiasMode::BIAS) +#define INSTANTIATION_CONV_S1_NO_BIAS(filter_size) \ + FOR_OP(filter_size, BiasMode::NO_BIAS) -// vim: syntax=cpp.doxygen \ No newline at end of file +#define INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(filter_size) \ + FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANTIATION_CONV_S1_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS) + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index b31cd438..cbbf047a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -550,9 +550,12 @@ void conv_bias::conv_direct_fp32_nchw44( INSTANTIATION(filter_size, bias, 
HSwishOp) \ INSTANTIATION(filter_size, bias, SigmoidOp) -#define INSTANTIATION_CONV_S2(filter_size) \ - FOR_OP(filter_size, BiasMode::NO_BIAS) \ - FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ - FOR_OP(filter_size, BiasMode::BIAS) +#define INSTANTIATION_CONV_S2_NO_BIAS(filter_size) \ + FOR_OP(filter_size, BiasMode::NO_BIAS) -// vim: syntax=cpp.doxygen \ No newline at end of file +#define INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(filter_size) \ + FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANTIATION_CONV_S2_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS) + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp index 291741b9..e82c464e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" -INSTANCE_CONV(2, 1); \ No newline at end of file +INSTANCE_CONV_BIAS(2, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..fbdb5ec7 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp new file mode 100644 index 00000000..97f2595c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(2, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp index d3f360c5..3acba768 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" -INSTANCE_CONV(2, 2); +INSTANCE_CONV_BIAS(2, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..1e4c7197 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp new file mode 100644 index 00000000..03bc548f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(2, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp index 432708e7..89bc21d5 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" -INSTANCE_CONV(3, 1); +INSTANCE_CONV_BIAS(3, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..fe811030 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp new file mode 100644 index 00000000..88cbe5f8 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(3, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp rename to dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp index 38ffe8ef..f2f02815 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,4 +11,5 @@ * implied. 
*/ #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" -INSTANCE_CONV(3, 2); +INSTANCE_CONV_BIAS(3, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..416e9839 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp new file mode 100644 index 00000000..bf3f792d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(3, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp new file mode 100644 index 00000000..f4d38c4e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BIAS(5, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..1dcb60b0 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp new file mode 100644 index 00000000..e32fbccb --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(5, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp new file mode 100644 index 00000000..a818401f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BIAS(5, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..be387827 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp new file mode 100644 index 00000000..64c9db59 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(5, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp new file mode 100644 index 00000000..6fb2e117 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BIAS(7, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp new file mode 100644 index 00000000..74ad5102 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp new file mode 100644 index 00000000..94af0cbd --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(7, 1); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp new file mode 100644 index 00000000..576213bc --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BIAS(7, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp new file mode 100644 index 00000000..58890e90 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp new file mode 100644 index 00000000..4a4e0f35 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp @@ -0,0 +1,15 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +INSTANCE_CONV_NO_BIAS(7, 2); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index 36daabc8..1869f2ff 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -928,9 +928,11 @@ void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( INSTANTIATION(stride, filter, bias, ReluOp) \ INSTANTIATION(stride, filter, bias, HSwishOp) -#define INSTANCE_CONV(filter, stride) \ - FOR_OP(stride, filter, BiasMode::NO_BIAS) \ - FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ - FOR_OP(stride, filter, BiasMode::BIAS) +#define INSTANCE_CONV_NO_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::NO_BIAS) + +#define INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(filter, stride) \ + FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANCE_CONV_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::BIAS) // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h similarity index 97% rename from dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp rename to dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h index cc9624f3..48d7bd1d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp +++ 
b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -265,7 +265,8 @@ void conv_direct_sdot_int8_nchw44( #define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ template void \ - conv_direct_sdot_int8_nchw44( \ + megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \ + dst_type, stride, bias_mode, Op, filter_size>( \ dst_type * dst, const int oh, const int ow, const int8_t* src, \ const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ const int oh_size, const int oc, const int ic, const Op& op); @@ -284,22 +285,6 @@ void conv_direct_sdot_int8_nchw44( FOR_OP(stride, i, BiasMode::NO_BIAS) \ FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(1) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION - } // namespace direct_dotprod_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp new file mode 100644 index 00000000..66ae2846 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(1, 2); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp new file mode 100644 index 00000000..faf8f46c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(1, 3); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp new file mode 100644 index 00000000..94fe0811 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(1, 5); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp new file mode 100644 index 00000000..001c58a9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(1, 7); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h similarity index 97% rename from dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp rename to dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h index 841e868f..21ce451b 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -266,7 +266,8 @@ void conv_direct_sdot_int8_nchw44( #define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ template void \ - conv_direct_sdot_int8_nchw44( \ + megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \ + dst_type, stride, bias_mode, Op, filter_size>( \ dst_type * dst, const int oh, const int ow, const int8_t* src, \ const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ const int oh_size, const int oc, const int ic, const Op& op); @@ -285,22 +286,6 @@ void conv_direct_sdot_int8_nchw44( FOR_OP(stride, i, BiasMode::NO_BIAS) \ FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION - } // namespace direct_dotprod_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp new file mode 100644 index 00000000..521ce682 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(2, 2); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp new file mode 100644 index 00000000..d2af6eca --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(2, 3); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp new file mode 100644 index 00000000..949c105e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(2, 5); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp new file mode 100644 index 00000000..4337a3e5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp @@ -0,0 +1,21 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h" +#if MGB_ENABLE_DOT +using namespace megdnn; +using namespace arm_common; + +FOR_BIAS(2, 7); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h index 74a11e4f..174c48c3 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -45,4 +45,4 @@ public: } // namespace arm_common } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp index f238a1fc..3373cf9a 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp @@ -13,336 +13,9 @@ #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" + namespace megdnn { namespace arm_common { -namespace { -/** - * @brief core code for calculation patten - * - * @tparam src_idx is offset of src reg - * @tparam weight_idx is offset of weight reg - * @tparam c_dim is output channel - * @tparam Func mla operation funcion - * @tparam stride - * @tparam T outpur regs type - * @tparam T2 src regs 
type - * @tparam T3 weight regs type - * @tparam T4 temp regs type - */ - -template < - int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, - typename T3, typename T4> -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); -}; -template < - int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, - typename T3, typename T4> -MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { - ShiftCalHelper::impl( - c, src, weight, temp); -} -template < - int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, - typename T3> -MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( - c, src, weight); -}; -template < - int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = vdotq_s32_h( - src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); - c[1][0] = vdotq_s32_h( - src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]); - c[0][1] = vdotq_s32_h( - src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]); - c[1][1] = vdotq_s32_h( - src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]); - c[0][2] = vdotq_s32_h( - src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]); - c[1][2] = vdotq_s32_h( - src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]); - c[0][3] = vdotq_s32_h( - src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]); - c[1][3] = vdotq_s32_h( - src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]); - - c[0][4] = vdotq_s32_h( - src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); - c[1][4] = vdotq_s32_h( - src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]); - c[0][5] = vdotq_s32_h( - src[(5 
+ src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]); - c[1][5] = vdotq_s32_h( - src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]); - c[0][6] = vdotq_s32_h( - src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]); - c[1][6] = vdotq_s32_h( - src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]); - c[0][7] = vdotq_s32_h( - src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]); - c[1][7] = vdotq_s32_h( - src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]); - } - static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); -}; -template < - int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = vdotq_s32_h( - src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); - c[0][1] = vdotq_s32_h( - src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]); - c[0][2] = vdotq_s32_h( - src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]); - c[0][3] = vdotq_s32_h( - src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]); - c[0][4] = vdotq_s32_h( - src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); - c[0][5] = vdotq_s32_h( - src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]); - c[0][6] = vdotq_s32_h( - src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]); - c[0][7] = vdotq_s32_h( - src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]); - } - static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl( - const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 1; - constexpr int filter_width = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - 
constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 1; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; - load_helper( - dot4_weight, weight_ptr, ld_weight_oc); - load_helper( - src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - - weight_ptr += oc_step * filter_height * filter_width; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl( - const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 2; - constexpr int filter_width = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 1; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; - load_helper( - dot4_weight, weight_ptr, ld_weight_oc); - load_helper( - src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - - load_helper( - 
dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); - load_helper( - src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - - weight_ptr += oc_step * filter_height * filter_width; - } - - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl( - const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 3; - constexpr int filter_width = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 1; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; - load_helper( - dot4_weight, weight_ptr, ld_weight_oc); - - load_helper( - src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - load_helper( - dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); - - load_helper( - src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - - load_helper( - dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc); - load_helper( - src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); - - weight_ptr += oc_step * filter_height * filter_width; - } - store_ocx_ow8_remain_static_dt( - c, op, 
dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl( - const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 5; - constexpr int filter_width = 8; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 2; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ - cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); - UNROLL_CALL_RAW(5, cb); -#undef cb - weight_ptr += oc_step * filter_height * filter_width; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl( - const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { - constexpr int stride = 1; - constexpr int filter_height = 7; - constexpr int filter_width = 8; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 1; - constexpr int simd_len = 16; - constexpr int 
pack_iw_len = 16; - constexpr int src_reg = 8; - constexpr int weight_reg = 2; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_height * filter_width * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[src_reg]; - int8x16_t dot4_weight[c_dim][weight_reg]; - int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ - cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); - - UNROLL_CALL_RAW(7, cb); -#undef cb - weight_ptr += oc_step * filter_height * filter_width; - } - store_ocx_ow8_remain_static_dt( - c, op, dst_ptr, ld_dst_oc); - } -}; -} // namespace - namespace int8_direct_nchw_nchw44 { /** * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} @@ -444,115 +117,9 @@ void pack_nchw_src_for_nchw44_conv<1>( } } -template -struct ConvDiectStrideInt8NchwNchw44 { - static void impl( - const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr int stride = 1; - constexpr size_t fh = filter_size; - constexpr size_t fw = (filter_size + 3) / 4 * 4; - constexpr size_t ic_step = 1; - constexpr size_t big_oc_step = 8; - constexpr size_t oc_step = 4; - constexpr size_t ih_step = 1; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = stride; - constexpr 
size_t stride_w = stride; - constexpr int pack_iw_len = 16; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = KerNeonXXs2NchwNchw44< \ - bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \ - kern_small_oc_remain = KerNeonXXs2NchwNchw44< \ - bias_mode, Op, step, filter_size, oc_step, stride>::impl; \ - break; - - UNROLL_CALL_RAW(8, cb); - default: - megdnn_assert(0, "no remain %zu for kern", ow_remain); - } - - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - - KerNeonXXs2NchwNchw44< - bias_mode, Op, ow_step, filter_size, big_oc_step, stride>:: - impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, - op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - } - } - - if (oc_remain > 0) { - size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; 
oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44< - bias_mode, Op, ow_step, filter_size, oc_step, stride>:: - impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, - op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_dst_oc, op); - } - } - } - } -}; - #define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ - template struct ConvDiectStrideInt8NchwNchw44; + template struct megdnn::arm_common::int8_direct_nchw_nchw44:: \ + ConvDiectStrideInt8NchwNchw44; #define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ INSTANCE_CONV_KERN_FUN( \ @@ -566,17 +133,10 @@ struct ConvDiectStrideInt8NchwNchw44 { INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) -#define INSTANCE_CONV_KERN(stride) \ - INSTANCE_BIAS_MODE_PARAM(stride, 1) \ - INSTANCE_BIAS_MODE_PARAM(stride, 2) \ - INSTANCE_BIAS_MODE_PARAM(stride, 3) \ - INSTANCE_BIAS_MODE_PARAM(stride, 5) \ - INSTANCE_BIAS_MODE_PARAM(stride, 7) - -INSTANCE_CONV_KERN(1); +#define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter) } // namespace int8_direct_nchw_nchw44 } // namespace arm_common } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h 
b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h new file mode 100644 index 00000000..4ae7f4f7 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h @@ -0,0 +1,481 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" +#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" + +namespace megdnn { +namespace arm_common { +namespace { +/** + * @brief core code for calculation pattern + * + * @tparam src_idx is offset of src reg + * @tparam weight_idx is offset of weight reg + * @tparam c_dim is output channel + * @tparam Func mla operation function + * @tparam stride + * @tparam T output regs type + * @tparam T2 src regs type + * @tparam T3 weight regs type + * @tparam T4 temp regs type + */ + +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3, typename T4> +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); +}; +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3, typename T4> +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { + ShiftCalHelper::impl( + c, src, weight, temp); +} +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3> +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src,
T3& weight) { + ShiftCalHelper::impl( + c, src, weight); +}; +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); + c[1][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]); + c[0][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]); + c[1][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]); + c[0][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]); + c[1][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]); + c[0][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]); + c[1][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]); + + c[0][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); + c[1][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]); + c[0][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]); + c[1][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]); + c[0][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]); + c[1][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]); + c[0][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]); + c[1][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]); + } + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); +}; +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, 
T4& temp) { + c[0][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); + c[0][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]); + c[0][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]); + c[0][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]); + c[0][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); + c[0][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]); + c[0][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]); + c[0][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]); + } + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 1; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + weight_ptr += oc_step * filter_height * 
filter_width; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 2; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 3; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int 
pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + load_helper( + dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 5; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + 
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ + cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); + UNROLL_CALL_RAW(5, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 7; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, 
temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ + cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); + + UNROLL_CALL_RAW(7, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +} // namespace + +namespace int8_direct_nchw_nchw44 { +/** + * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} + * pack interleave two adjacent row in filter to one row + * */ +template +struct ConvDiectStrideInt8NchwNchw44 { + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr int stride = 1; + constexpr size_t fh = filter_size; + constexpr size_t fw = (filter_size + 3) / 4 * 4; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = 1; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = stride; + constexpr size_t stride_w = stride; + constexpr int pack_iw_len = 16; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \ + kern_small_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, oc_step, stride>::impl; \ + break; + + 
UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + + KerNeonXXs2NchwNchw44< + bias_mode, Op, ow_step, filter_size, big_oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + } + } + + if (oc_remain > 0) { + size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44< + bias_mode, Op, ow_step, filter_size, oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + 
src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + } + } + } +}; + +#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template struct megdnn::arm_common::int8_direct_nchw_nchw44:: \ + ConvDiectStrideInt8NchwNchw44; + +#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, TypeCvtOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, ReluOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, HSwishOp) + +#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter) + +} // namespace int8_direct_nchw_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp new file mode 100644 index 00000000..14763c96 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" +using namespace megdnn; +using namespace arm_common; + +INSTANCE_CONV_KERN(1, 1); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp new file mode 100644 index 00000000..10d46268 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" +using namespace megdnn; +using namespace arm_common; + +INSTANCE_CONV_KERN(1, 2); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp new file mode 100644 index 00000000..87553278 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" +using namespace megdnn; +using namespace arm_common; + +INSTANCE_CONV_KERN(1, 3); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp new file mode 100644 index 00000000..d7deb345 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" +using namespace megdnn; +using namespace arm_common; + +INSTANCE_CONV_KERN(1, 5); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp new file mode 100644 index 00000000..37cb7679 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ +#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h" +using namespace megdnn; +using namespace arm_common; + +INSTANCE_CONV_KERN(1, 7); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp similarity index 83% rename from dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp index 8b959f54..e6b8576d 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -12,8 +12,5 @@ */ #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(2, 1); -INSTANCE_CONV(3, 1); -INSTANCE_CONV(5, 1); -INSTANCE_CONV(7, 1); -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp similarity index 83% rename from dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp index 050a7df8..7f78d864 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -12,8 +12,5 @@ */ #include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(2, 2); -INSTANCE_CONV(3, 2); -INSTANCE_CONV(5, 2); -INSTANCE_CONV(7, 2); -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp new file mode 100644 index 00000000..a970478d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp @@ -0,0 +1,16 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" +INSTANCE_CONV(3, 1); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp new file mode 100644 index 00000000..532351d8 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp @@ -0,0 +1,16 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" +INSTANCE_CONV(3, 2); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp similarity index 66% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp index 2bd616d6..ffe5decc 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,5 +10,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(5, 1); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp similarity index 66% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp index 8433d0de..7a64dbe4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,5 +10,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(5, 2); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp similarity index 66% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp index deb839a8..154903f9 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,5 +10,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(7, 1); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp similarity index 66% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp rename to dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp index c0a18167..83b9e21b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,5 +10,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" INSTANCE_CONV(7, 2); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/opr_binary_impl.cpp b/dnn/src/fallback/elemwise/opr_binary_impl.cpp new file mode 100644 index 00000000..9acda94e --- /dev/null +++ b/dnn/src/fallback/elemwise/opr_binary_impl.cpp @@ -0,0 +1,275 @@ +/** + * \file dnn/src/fallback/elemwise/opr_binary_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "./opr_impl.h" + +#include "src/common/elemwise/kern_defs.cuh" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_elemwise_binary) + +namespace megdnn { +namespace fallback { + +template +void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) { + using ctype = typename DTypeTrait::ctype; + using Kern = ElemwiseKern; + + MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) { + if (param.max_ndim == 1) { + MIDOUT_BEGIN( + megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), + midout_iv(1)) { + auto tot = param.size; + auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0]; + auto src0 = param[0]; + auto src1 = param[1]; + auto dst_tensor = *m_dst; + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict a = static_cast(src0.raw_ptr()); + ctype* __restrict b = static_cast(src1.raw_ptr()); + ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); + for (size_t i = 0; i < tot; ++i) { + dst[i] = Kern::apply(a[i * as], b[i * bs]); + } + }); + return; + } + MIDOUT_END(); + } + + if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) { + return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); + } + + if (param.max_ndim == 2) { + if (param[0].layout.ndim == 1) { + MIDOUT_BEGIN( + megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), + midout_iv(21)) { + auto as = param[0].layout.stride[0], + bs0 = param[1].layout.stride[0], + bs1 = param[1].layout.stride[1]; + auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1]; + auto src0 = param[0]; + auto src1 = param[1]; + auto dst_tensor = *m_dst; + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict a = static_cast(src0.raw_ptr()); + ctype* __restrict b = 
static_cast(src1.raw_ptr()); + ctype* __restrict dst = + static_cast(dst_tensor.raw_ptr()); + ptrdiff_t toff = 0; + for (size_t i = 0; i < n0; ++i) { + for (size_t j = 0; j < n1; ++j) { + dst[toff] = + Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]); + ++toff; + } + } + }); + return; + } + MIDOUT_END(); + } + + MIDOUT_BEGIN( + megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), + midout_iv(22)) { + megdnn_assert(param[1].layout.ndim == 1); + auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0], + as1 = param[0].layout.stride[1]; + auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1]; + auto src0 = param[0]; + auto src1 = param[1]; + auto dst_tensor = *m_dst; + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict a = static_cast(src0.raw_ptr()); + ctype* __restrict b = static_cast(src1.raw_ptr()); + ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); + ptrdiff_t toff = 0; + for (size_t i = 0; i < n0; ++i) { + for (size_t j = 0; j < n1; ++j) { + dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]); + ++toff; + } + } + }); + return; + } + MIDOUT_END(); + } + + if (param.max_ndim == 3) { + auto brd_101 = [](const TensorND& t) { + auto&& l = t.layout; + return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0; + }; + if (param[0].layout.ndim == 1 && brd_101(param[1])) { + MIDOUT_BEGIN( + megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), + midout_iv(31)) { + auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1]; + auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1], + n2 = param[1].layout.shape[2]; + auto src0 = param[0]; + auto src1 = param[1]; + auto dst_tensor = *m_dst; + + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict a = static_cast(src0.raw_ptr()); + ctype* __restrict b = static_cast(src1.raw_ptr()); + ctype* __restrict dst = + static_cast(dst_tensor.raw_ptr()); + size_t toff = 0; + for (size_t i = 0; i < n0; ++i) { + for (size_t j = 0; j < n1; ++j) { + for (size_t k = 0; k < 
n2; ++k) { + dst[toff] = Kern::apply(a[as * toff], b[bs * j]); + ++toff; + } + } + } + }); + return; + } + MIDOUT_END(); + } + if (param[1].layout.ndim == 1 && brd_101(param[0])) { + MIDOUT_BEGIN( + megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), + midout_iv(32)) { + auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0]; + auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1], + n2 = param[0].layout.shape[2]; + auto src0 = param[0]; + auto src1 = param[1]; + auto dst_tensor = *m_dst; + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict a = static_cast(src0.raw_ptr()); + ctype* __restrict b = static_cast(src1.raw_ptr()); + ctype* __restrict dst = + static_cast(dst_tensor.raw_ptr()); + size_t toff = 0; + for (size_t i = 0; i < n0; ++i) { + for (size_t j = 0; j < n1; ++j) { + for (size_t k = 0; k < n2; ++k) { + dst[toff] = Kern::apply(a[as * j], b[bs * toff]); + ++toff; + } + } + } + }); + return; + } + MIDOUT_END(); + } + } + + naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); + } + MIDOUT_END(); +} + +#define SWITCH_DTYPE(_cat, _cb) \ + switch (m_dst->layout.dtype.enumv()) { \ + MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ + : megdnn_throw("bad dtype"); \ + } + +template +void ElemwiseImpl::exec_BINARY_INT() { + auto param = make_elemwise_op_param<2>(); +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + return binary_kern<_dt, mode>(param); + + SWITCH_DTYPE(INT, cb) + +#undef cb +} + +template +void ElemwiseImpl::exec_BINARY_FLOAT() { + auto param = make_elemwise_op_param<2>(); +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + return binary_kern<_dt, mode>(param); + + SWITCH_DTYPE(FLOAT, cb) + +#undef cb +} + +#undef SWITCH_DTYPE + +#undef SWITCH_DTYPE +using Mode = param_enumv::Elemwise::Mode; +#define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_BINARY_INT() +INST(Mode::ABS_GRAD); +INST(Mode::ADD); +INST(Mode::FLOOR_DIV); +INST(Mode::MAX); +INST(Mode::MIN); +INST(Mode::MOD); +INST(Mode::MUL); 
+INST(Mode::SIGMOID_GRAD); +INST(Mode::SUB); +INST(Mode::SWITCH_GT0); +INST(Mode::TANH_GRAD); +INST(Mode::LT); +INST(Mode::LEQ); +INST(Mode::EQ); +INST(Mode::SHL); +INST(Mode::SHR); +INST(Mode::FUSE_ADD_RELU); +INST(Mode::RMULH); +#undef INST + +#define INST(mode) \ + template void megdnn::fallback::ElemwiseImpl::exec_BINARY_FLOAT() +INST(Mode::ABS_GRAD); +INST(Mode::ADD); +INST(Mode::FLOOR_DIV); +INST(Mode::MAX); +INST(Mode::MIN); +INST(Mode::MOD); +INST(Mode::MUL); +INST(Mode::POW); +INST(Mode::SIGMOID_GRAD); +INST(Mode::SUB); +INST(Mode::SWITCH_GT0); +INST(Mode::TANH_GRAD); +INST(Mode::TRUE_DIV); +INST(Mode::LOG_SUM_EXP); +INST(Mode::LT); +INST(Mode::LEQ); +INST(Mode::EQ); +INST(Mode::FUSE_ADD_RELU); +INST(Mode::FUSE_ADD_SIGMOID); +INST(Mode::FUSE_ADD_TANH); +INST(Mode::FAST_TANH_GRAD); +INST(Mode::ATAN2); +INST(Mode::H_SWISH_GRAD); +INST(Mode::FUSE_ADD_H_SWISH); +INST(Mode::SILU_GRAD); +INST(Mode::GELU_GRAD); +#undef INST +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/opr_impl.cpp b/dnn/src/fallback/elemwise/opr_impl.cpp index a9fd7815..eb4b2d9c 100644 --- a/dnn/src/fallback/elemwise/opr_impl.cpp +++ b/dnn/src/fallback/elemwise/opr_impl.cpp @@ -16,8 +16,6 @@ #include "midout.h" -MIDOUT_DECL(megdnn_fallback_elemwise_unary) -MIDOUT_DECL(megdnn_fallback_elemwise_binary) MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_INT) MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT) MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) @@ -26,200 +24,6 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) namespace megdnn { namespace fallback { -template -void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) { - using ctype = typename DTypeTrait::ctype; - using Kern = ElemwiseKern; - MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) { - // only specialize for the most common 1-dim case - auto tot = param.size; - auto stride = param[0].layout.stride[0]; - auto src0 = 
param[0]; - auto dst_tensor = *m_dst; - if (param.max_ndim == 1) { - MIDOUT_BEGIN( - megdnn_fallback_elemwise_unary, ctype, midout_iv(mode), - midout_iv(1)) { - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict src = static_cast(src0.raw_ptr()); - ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); - for (size_t i = 0; i < tot; ++i) { - dst[i] = Kern::apply(src[i * stride]); - } - }); - return; - } - MIDOUT_END(); - } - naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); - } - MIDOUT_END(); -} - -template -void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) { - using ctype = typename DTypeTrait::ctype; - using Kern = ElemwiseKern; - - MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) { - if (param.max_ndim == 1) { - MIDOUT_BEGIN( - megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), - midout_iv(1)) { - auto tot = param.size; - auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0]; - auto src0 = param[0]; - auto src1 = param[1]; - auto dst_tensor = *m_dst; - - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict a = static_cast(src0.raw_ptr()); - ctype* __restrict b = static_cast(src1.raw_ptr()); - ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); - for (size_t i = 0; i < tot; ++i) { - dst[i] = Kern::apply(a[i * as], b[i * bs]); - } - }); - return; - } - MIDOUT_END(); - } - - if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) { - return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); - } - - if (param.max_ndim == 2) { - if (param[0].layout.ndim == 1) { - MIDOUT_BEGIN( - megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), - midout_iv(21)) { - auto as = param[0].layout.stride[0], - bs0 = param[1].layout.stride[0], - bs1 = param[1].layout.stride[1]; - auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1]; - auto src0 = param[0]; - auto src1 = param[1]; - auto dst_tensor = *m_dst; - - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict a = static_cast(src0.raw_ptr()); - ctype* 
__restrict b = static_cast(src1.raw_ptr()); - ctype* __restrict dst = - static_cast(dst_tensor.raw_ptr()); - ptrdiff_t toff = 0; - for (size_t i = 0; i < n0; ++i) { - for (size_t j = 0; j < n1; ++j) { - dst[toff] = - Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]); - ++toff; - } - } - }); - return; - } - MIDOUT_END(); - } - - MIDOUT_BEGIN( - megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), - midout_iv(22)) { - megdnn_assert(param[1].layout.ndim == 1); - auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0], - as1 = param[0].layout.stride[1]; - auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1]; - auto src0 = param[0]; - auto src1 = param[1]; - auto dst_tensor = *m_dst; - - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict a = static_cast(src0.raw_ptr()); - ctype* __restrict b = static_cast(src1.raw_ptr()); - ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); - ptrdiff_t toff = 0; - for (size_t i = 0; i < n0; ++i) { - for (size_t j = 0; j < n1; ++j) { - dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]); - ++toff; - } - } - }); - return; - } - MIDOUT_END(); - } - - if (param.max_ndim == 3) { - auto brd_101 = [](const TensorND& t) { - auto&& l = t.layout; - return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0; - }; - if (param[0].layout.ndim == 1 && brd_101(param[1])) { - MIDOUT_BEGIN( - megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), - midout_iv(31)) { - auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1]; - auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1], - n2 = param[1].layout.shape[2]; - auto src0 = param[0]; - auto src1 = param[1]; - auto dst_tensor = *m_dst; - - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict a = static_cast(src0.raw_ptr()); - ctype* __restrict b = static_cast(src1.raw_ptr()); - ctype* __restrict dst = - static_cast(dst_tensor.raw_ptr()); - size_t toff = 0; - for (size_t i = 0; i < n0; ++i) { - for (size_t j = 0; j < n1; ++j) { - for 
(size_t k = 0; k < n2; ++k) { - dst[toff] = Kern::apply(a[as * toff], b[bs * j]); - ++toff; - } - } - } - }); - return; - } - MIDOUT_END(); - } - if (param[1].layout.ndim == 1 && brd_101(param[0])) { - MIDOUT_BEGIN( - megdnn_fallback_elemwise_binary, ctype, midout_iv(mode), - midout_iv(32)) { - auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0]; - auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1], - n2 = param[0].layout.shape[2]; - auto src0 = param[0]; - auto src1 = param[1]; - auto dst_tensor = *m_dst; - MEGDNN_DISPATCH_CPU_KERN_OPR({ - ctype* __restrict a = static_cast(src0.raw_ptr()); - ctype* __restrict b = static_cast(src1.raw_ptr()); - ctype* __restrict dst = - static_cast(dst_tensor.raw_ptr()); - size_t toff = 0; - for (size_t i = 0; i < n0; ++i) { - for (size_t j = 0; j < n1; ++j) { - for (size_t k = 0; k < n2; ++k) { - dst[toff] = Kern::apply(a[as * j], b[bs * toff]); - ++toff; - } - } - } - }); - return; - } - MIDOUT_END(); - } - } - - naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); - } - MIDOUT_END(); -} - void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { if (!dst.layout.is_contiguous()) { return naive::ElemwiseForwardImpl::exec(srcs, dst); @@ -278,62 +82,6 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { naive::ElemwiseForwardImpl::exec(srcs, dst); } -#define SWITCH_DTYPE(_cat, _cb) \ - switch (m_dst->layout.dtype.enumv()) { \ - MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ - : megdnn_throw("bad dtype"); \ - } - -template -void ElemwiseImpl::exec_UNARY_INT() { - auto param = make_elemwise_op_param<1>(); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - return unary_kern<_dt, mode>(param); - - SWITCH_DTYPE(INT, cb) - -#undef cb -} - -template -void ElemwiseImpl::exec_UNARY_FLOAT() { - auto param = make_elemwise_op_param<1>(); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - return unary_kern<_dt, mode>(param); - - SWITCH_DTYPE(FLOAT, cb) - 
-#undef cb -} - -template -void ElemwiseImpl::exec_BINARY_INT() { - auto param = make_elemwise_op_param<2>(); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - return binary_kern<_dt, mode>(param); - - SWITCH_DTYPE(INT, cb) - -#undef cb -} - -template -void ElemwiseImpl::exec_BINARY_FLOAT() { - auto param = make_elemwise_op_param<2>(); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - return binary_kern<_dt, mode>(param); - - SWITCH_DTYPE(FLOAT, cb) - -#undef cb -} - -#undef SWITCH_DTYPE - } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/elemwise/opr_unary_impl.cpp b/dnn/src/fallback/elemwise/opr_unary_impl.cpp new file mode 100644 index 00000000..af829358 --- /dev/null +++ b/dnn/src/fallback/elemwise/opr_unary_impl.cpp @@ -0,0 +1,122 @@ +/** + * \file dnn/src/fallback/elemwise/opr_unary_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "./opr_impl.h" + +#include "src/common/elemwise/kern_defs.cuh" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_elemwise_unary) + +namespace megdnn { +namespace fallback { + +template +void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) { + using ctype = typename DTypeTrait::ctype; + using Kern = ElemwiseKern; + MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) { + // only specialize for the most common 1-dim case + auto tot = param.size; + auto stride = param[0].layout.stride[0]; + auto src0 = param[0]; + auto dst_tensor = *m_dst; + if (param.max_ndim == 1) { + MIDOUT_BEGIN( + megdnn_fallback_elemwise_unary, ctype, midout_iv(mode), + midout_iv(1)) { + MEGDNN_DISPATCH_CPU_KERN_OPR({ + ctype* __restrict src = static_cast(src0.raw_ptr()); + ctype* __restrict dst = static_cast(dst_tensor.raw_ptr()); + for (size_t i = 0; i < tot; ++i) { + dst[i] = Kern::apply(src[i * stride]); + } + }); + return; + } + MIDOUT_END(); + } + naive::ElemwiseForwardImpl::exec(*m_src, *m_dst); + } + MIDOUT_END(); +} + +#define SWITCH_DTYPE(_cat, _cb) \ + switch (m_dst->layout.dtype.enumv()) { \ + MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \ + : megdnn_throw("bad dtype"); \ + } + +template +void ElemwiseImpl::exec_UNARY_INT() { + auto param = make_elemwise_op_param<1>(); +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + return unary_kern<_dt, mode>(param); + + SWITCH_DTYPE(INT, cb) + +#undef cb +} + +template +void ElemwiseImpl::exec_UNARY_FLOAT() { + auto param = make_elemwise_op_param<1>(); +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + return unary_kern<_dt, mode>(param); + + SWITCH_DTYPE(FLOAT, cb) + +#undef cb +} + +#undef SWITCH_DTYPE +using Mode = param_enumv::Elemwise::Mode; +#define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_UNARY_INT(); +INST(Mode::RELU); +INST(Mode::ABS); +INST(Mode::NEGATE); +#undef INST + +#define INST(mode) \ + 
template void megdnn::fallback::ElemwiseImpl::exec_UNARY_FLOAT(); +INST(Mode::RELU); +INST(Mode::ABS); +INST(Mode::ACOS); +INST(Mode::ASIN); +INST(Mode::CEIL); +INST(Mode::COS); +INST(Mode::EXP); +INST(Mode::EXPM1); +INST(Mode::FLOOR); +INST(Mode::LOG); +INST(Mode::LOG1P); +INST(Mode::NEGATE); +INST(Mode::SIGMOID); +INST(Mode::SIN); +INST(Mode::TANH); +INST(Mode::FAST_TANH); +INST(Mode::ROUND); +INST(Mode::ERF); +INST(Mode::ERFINV); +INST(Mode::ERFC); +INST(Mode::ERFCINV); +INST(Mode::H_SWISH); +INST(Mode::SILU); +INST(Mode::GELU); +#undef INST +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp new file mode 100644 index 00000000..1c135917 --- /dev/null +++ b/dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp @@ -0,0 +1,138 @@ +/** + * \file dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "./opr_impl.h" +#include "megdnn/tensor_iter.h" +#include "src/common/elemwise/kern_defs.cuh" +#include "src/common/elemwise_multi_type/kern_defs.cuh" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace naive; + +void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( + const ElemwiseOpParamN<3>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto src2 = param[2]; + auto work = [src0, src1, src2, size, dst]() { + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto i2 = tensor_iter_valonly(src2).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + dst_ptr[i] = (*i0) * (*i1) + (*i2); + ++i0; + ++i1; + ++i2; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32( + const ElemwiseOpParamN<3>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto src2 = param[2]; + auto work = [src0, src1, src2, size, dst]() { + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto i2 = tensor_iter_valonly(src2).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + dst_ptr[i] = (*i0) * (*i1) + (*i2); + ++i0; + ++i1; + ++i2; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32( + const ElemwiseOpParamN<3>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto src2 = param[2]; + auto work = [src0, src1, src2, size, dst]() { + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto i2 = tensor_iter_valonly(src2).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + dst_ptr[i] = (*i0) * (*i1) + (*i2); + ++i0; + ++i1; + ++i2; + } + }; + 
MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32( + const ElemwiseOpParamN<2>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto work = [src0, src1, size, dst]() { + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + dst_ptr[i] = (*i0) * (*i1); + ++i0; + ++i1; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( + const ElemwiseOpParamN<3>& param, const TensorND& dst) { + switch (param[0].layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + return dispatch_fma3_iXxf32xf32xi8::ctype>(param, dst); + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) +#undef cb + default: + megdnn_throw("unsupported src dtype"); + } +} + +template +void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8( + const ElemwiseOpParamN<3>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto src2 = param[2]; + auto work = [src0, src1, src2, size, dst]() { + elemwise_multi_type::Fma3iXxf32xf32xiYOp op; + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto i2 = tensor_iter_valonly(src2).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + dst_ptr[i] = op(*i0, *i1, *i2); + ++i0; + ++i1; + ++i2; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp new file mode 100644 index 00000000..45861c78 --- /dev/null +++ b/dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp @@ -0,0 +1,115 @@ +/** + * \file dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 
Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./opr_impl.h" +#include "megdnn/tensor_iter.h" +#include "src/common/elemwise/kern_defs.cuh" +#include "src/common/elemwise_multi_type/kern_defs.cuh" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace naive; + +void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( + const ElemwiseOpParamN<2>& param, const TensorND& dst) { + switch (param[0].layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + return dispatch_round_shr_saturate_iXxi8xiX::ctype, dt_int8>( \ + param, dst); + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) +#undef cb + default: + megdnn_throw("unsupported src dtype"); + } +} + +template +void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX( + const ElemwiseOpParamN<2>& param, const TensorND& dst) { + auto src0 = param[0]; + auto src1 = param[1]; + auto size = param.size; + auto work = [src0, src1, size, dst]() { + // This is needed as these iterators are captured as const value. 
+ auto iA = tensor_iter_valonly(src0).begin(); + auto iB = tensor_iter_valonly(src1).begin(); + auto pD = dst.ptr(); + for (size_t i = 0; i < size; i++) { + *pD = elemwise_multi_type::round_shr_saturate(*iA, *iB); + ++iA; + ++iB; + ++pD; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} +template +void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate( + const ElemwiseOpParamN<6>& param, const TensorND& dst) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto src2 = param[2]; + auto src3 = param[3]; + auto src4 = param[4]; + auto src5 = param[5]; + auto work = [size, src0, src1, src2, src3, src4, src5, dst]() { + auto i0 = tensor_iter_valonly(src0).begin(); + auto i1 = tensor_iter_valonly(src1).begin(); + auto i2 = tensor_iter_valonly(src2).begin(); + auto ioff = tensor_iter_valonly(src3).begin(); + auto imin = tensor_iter_valonly(src4).begin(); + auto imax = tensor_iter_valonly(src5).begin(); + auto dst_ptr = dst.ptr(); + for (size_t i = 0; i < size; ++i) { + auto res = elemwise_multi_type::round_shr_saturate( + round_mulh_saturate(*i0 + *i1, *i2), *ioff); + res = std::min(res, *imax); + res = std::max(res, *imin); + dst_ptr[i] = res; + ++i0; + ++i1; + ++i2; + ++ioff; + ++imin; + ++imax; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); +} + +void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( + const ElemwiseOpParamN<6>& param, const TensorND& dst) { + dispatch_fuse_add_rmulh_round_shr_saturate(param, dst); +} + +void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( + const ElemwiseOpParamN<6>& param, const TensorND& dst) { + dispatch_fuse_add_rmulh_round_shr_saturate(param, dst); +} + +void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( + const ElemwiseOpParamN<2>& param, const TensorND& dst) { + switch (param[0].layout.dtype.enumv()) { +#define cb(t) \ + case DTypeTrait::enumv: \ + return dispatch_round_shr_saturate_iXxi8xiX::ctype, dt_int16>( \ + 
param, dst); + cb(::megdnn::dtype::Int32); + cb(::megdnn::dtype::Int16); +#undef cb + default: + megdnn_throw("unsupported src dtype"); + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp similarity index 56% rename from dnn/src/naive/elemwise_multi_type/opr_impl.cpp rename to dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp index 502cef74..67c21b12 100644 --- a/dnn/src/naive/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/naive/elemwise_multi_type/opr_impl.cpp + * \file dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -18,218 +18,6 @@ using namespace megdnn; using namespace naive; -void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( - const ElemwiseOpParamN<3>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto src2 = param[2]; - auto work = [src0, src1, src2, size, dst]() { - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto i2 = tensor_iter_valonly(src2).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - dst_ptr[i] = (*i0) * (*i1) + (*i2); - ++i0; - ++i1; - ++i2; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32( - const ElemwiseOpParamN<3>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto src2 = param[2]; - auto work = [src0, src1, src2, size, dst]() { - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto i2 = tensor_iter_valonly(src2).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - dst_ptr[i] = (*i0) * 
(*i1) + (*i2); - ++i0; - ++i1; - ++i2; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32( - const ElemwiseOpParamN<3>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto src2 = param[2]; - auto work = [src0, src1, src2, size, dst]() { - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto i2 = tensor_iter_valonly(src2).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - dst_ptr[i] = (*i0) * (*i1) + (*i2); - ++i0; - ++i1; - ++i2; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32( - const ElemwiseOpParamN<2>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto work = [src0, src1, size, dst]() { - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - dst_ptr[i] = (*i0) * (*i1); - ++i0; - ++i1; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( - const ElemwiseOpParamN<3>& param, const TensorND& dst) { - switch (param[0].layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - return dispatch_fma3_iXxf32xf32xi8::ctype>(param, dst); - MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) -#undef cb - default: - megdnn_throw("unsupported src dtype"); - } -} - -template -void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8( - const ElemwiseOpParamN<3>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto src2 = param[2]; - auto work = [src0, src1, src2, size, dst]() { - elemwise_multi_type::Fma3iXxf32xf32xiYOp op; - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto i2 = 
tensor_iter_valonly(src2).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - dst_ptr[i] = op(*i0, *i1, *i2); - ++i0; - ++i1; - ++i2; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( - const ElemwiseOpParamN<2>& param, const TensorND& dst) { - switch (param[0].layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - return dispatch_round_shr_saturate_iXxi8xiX::ctype, dt_int8>( \ - param, dst); - MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) -#undef cb - default: - megdnn_throw("unsupported src dtype"); - } -} - -template -void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX( - const ElemwiseOpParamN<2>& param, const TensorND& dst) { - auto src0 = param[0]; - auto src1 = param[1]; - auto size = param.size; - auto work = [src0, src1, size, dst]() { - // This is needed as these iterators are captured as const value. - auto iA = tensor_iter_valonly(src0).begin(); - auto iB = tensor_iter_valonly(src1).begin(); - auto pD = dst.ptr(); - for (size_t i = 0; i < size; i++) { - *pD = elemwise_multi_type::round_shr_saturate(*iA, *iB); - ++iA; - ++iB; - ++pD; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -template -void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate( - const ElemwiseOpParamN<6>& param, const TensorND& dst) { - auto size = param.size; - auto src0 = param[0]; - auto src1 = param[1]; - auto src2 = param[2]; - auto src3 = param[3]; - auto src4 = param[4]; - auto src5 = param[5]; - auto work = [size, src0, src1, src2, src3, src4, src5, dst]() { - auto i0 = tensor_iter_valonly(src0).begin(); - auto i1 = tensor_iter_valonly(src1).begin(); - auto i2 = tensor_iter_valonly(src2).begin(); - auto ioff = tensor_iter_valonly(src3).begin(); - auto imin = tensor_iter_valonly(src4).begin(); - auto imax = tensor_iter_valonly(src5).begin(); - auto dst_ptr = dst.ptr(); - for (size_t i = 0; i < size; ++i) { - auto res = 
elemwise_multi_type::round_shr_saturate( - round_mulh_saturate(*i0 + *i1, *i2), *ioff); - res = std::min(res, *imax); - res = std::max(res, *imin); - dst_ptr[i] = res; - ++i0; - ++i1; - ++i2; - ++ioff; - ++imin; - ++imax; - } - }; - MEGDNN_DISPATCH_CPU_KERN_OPR(work()); -} - -void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( - const ElemwiseOpParamN<6>& param, const TensorND& dst) { - dispatch_fuse_add_rmulh_round_shr_saturate(param, dst); -} - -void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( - const ElemwiseOpParamN<6>& param, const TensorND& dst) { - dispatch_fuse_add_rmulh_round_shr_saturate(param, dst); -} - -void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( - const ElemwiseOpParamN<2>& param, const TensorND& dst) { - switch (param[0].layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - return dispatch_round_shr_saturate_iXxi8xiX::ctype, dt_int16>( \ - param, dst); - cb(::megdnn::dtype::Int32); - cb(::megdnn::dtype::Int16); -#undef cb - default: - megdnn_throw("unsupported src dtype"); - } -} - template void ElemwiseMultiTypeImpl::dispatch_add_qint_op( const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {