From ff6a3bb819ec90ddfb51d7920ea06a20acd95f12 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 9 Mar 2022 18:31:34 +0800
Subject: [PATCH] fix(fallback): delete the duplicated OpCaller in fallback and arm_common

GitOrigin-RevId: 87046b81977acc8f12f2d46e8361956b9e050e20
---
 dnn/src/aarch64/conv_bias/int8/algos.cpp | 2 +-
 dnn/src/aarch64/conv_bias/quint8/algos.cpp | 2 +-
 .../f16/channel_wise_3x3_s1p1_nchw88_kern.cpp | 2 +-
 .../conv_bias/f16/channel_wise_nchw88_algo.cpp | 2 +-
 .../conv_bias/f16/channel_wise_nchw88_kern.cpp | 2 +-
 .../conv_bias/f16/direct_nchw88_algo.cpp | 2 +-
 .../conv_bias/f16/direct_nchw88_kern.cpp | 2 +-
 .../fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | 2 +-
 .../fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | 2 +-
 .../conv_bias/fp32/channel_wise_nchw44_algo.cpp | 2 +-
 .../conv_bias/fp32/channel_wise_nchw44_kern.cpp | 2 +-
 .../f32_direct_nchw44_kern_common_s1.h | 2 +-
 .../f32_direct_nchw44_kern_common_s2.h | 2 +-
 .../f32_direct_nchw_nchw44_kern_common.h | 2 +-
 .../conv_bias/fp32/f32_direct_nchw44_algo.cpp | 2 +-
 .../conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp | 2 +-
 .../conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | 2 +-
 dnn/src/arm_common/conv_bias/int8/algos.cpp | 2 +-
 .../conv_bias/int8/channel_wise_kernel.cpp | 2 +-
 .../conv_bias/int8/channel_wise_nchw44.cpp | 2 +-
 dnn/src/arm_common/conv_bias/int8/direct.cpp | 2 +-
 .../arm_common/conv_bias/int8/direct_dotprod.cpp | 2 +-
 .../conv_bias/int8/direct_dotprod_nchw44_algo.cpp | 2 +-
 .../int8/direct_kernels/dot_direct_nchw44_common.h | 2 +-
 .../int8/direct_kernels/int8_direct_nchw44_s1.cpp | 2 +-
 .../int8/direct_kernels/int8_direct_nchw44_s2.cpp | 2 +-
 .../conv_bias/int8/direct_nchw44_algo.cpp | 2 +-
 .../arm_common/conv_bias/int8/direct_nchw44_kern.h | 2 +-
 .../conv_bias/int8/direct_nchw_nchw44_algo.cpp | 2 +-
 .../conv_bias/int8/direct_nchw_nchw44_kern.h | 2 +-
 .../conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp | 2 +-
 .../conv_bias/int8/dot_direct_nchw_nchw44_kern.h | 2 +-
 dnn/src/arm_common/conv_bias/int8/stride1.cpp | 2 +-
 .../arm_common/conv_bias/int8/stride1_dotprod.cpp | 2 +-
 dnn/src/arm_common/conv_bias/int8/stride2.cpp | 2 +-
 .../arm_common/conv_bias/int8/stride2_dotprod.cpp | 2 +-
 .../channel_wise_kernel_int8x8x16_nchw44.cpp | 2 +-
 .../int8x8x16/direct_nchw_nchw44_algo.cpp | 2 +-
 .../conv_bias/int8x8x16/direct_nchw_nchw44_kern.h | 2 +-
 .../kernel/direct_nchw_nchw44_kern_impl.h | 2 +-
 dnn/src/arm_common/conv_bias/matmul_postprocess.h | 2 +-
 dnn/src/arm_common/conv_bias/postprocess_helper.h | 120 +-
 dnn/src/arm_common/conv_bias/quint8/algos.cpp | 2 +-
 dnn/src/arm_common/conv_bias/quint8/direct.cpp | 2 +-
 .../arm_common/conv_bias/quint8/direct_dotprod.cpp | 2 +-
 dnn/src/arm_common/conv_bias/quint8/stride1.cpp | 2 +-
 .../conv_bias/quint8/stride1_dotprod.cpp | 2 +-
 dnn/src/arm_common/conv_bias/quint8/stride2.cpp | 2 +-
 .../conv_bias/quint8/stride2_dotprod.cpp | 2 +-
 dnn/src/arm_common/elemwise/binary/algo.cpp | 7 +-
 dnn/src/arm_common/elemwise/opr_impl.cpp | 2 +-
 dnn/src/arm_common/elemwise/opr_impl.h | 2 +-
 dnn/src/arm_common/elemwise/ternary/algo.cpp | 3 +-
 dnn/src/arm_common/elemwise/unary/algo.cpp | 3 +-
 dnn/src/arm_common/elemwise_helper/elemwise_op.h | 151 ++
 dnn/src/arm_common/elemwise_helper/kimpl/pow.h | 36 -
 dnn/src/arm_common/elemwise_helper/op_binary.h | 1 -
 .../arm_common/elemwise_multi_type/opr_impl.cpp | 4 +-
 dnn/src/arm_common/elemwise_op.h | 1537 --------------------
 dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | 3 +-
 dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp | 3 +-
dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | 3 +- dnn/src/fallback/elemwise/opr_impl.cpp | 1 + dnn/src/fallback/elemwise/opr_impl.h | 4 +- dnn/src/fallback/elemwise_helper/elemwise_op.h | 72 + dnn/src/fallback/elemwise_helper/kimpl/abs.h | 2 +- .../elemwise_helper/kimpl/fuse_add_h_swish.h | 2 +- dnn/src/fallback/elemwise_helper/kimpl/hswish.h | 2 +- dnn/src/fallback/elemwise_helper/kimpl/max.h | 2 +- dnn/src/fallback/elemwise_helper/kimpl/min.h | 2 +- dnn/src/fallback/elemwise_helper/kimpl/mul.h | 2 +- dnn/src/fallback/elemwise_helper/kimpl/none.h | 8 +- dnn/src/fallback/elemwise_helper/kimpl/relu.h | 20 +- dnn/src/fallback/elemwise_helper/kimpl/sub.h | 2 +- dnn/src/fallback/elemwise_helper/op_common.h | 1370 +++++++++++++++++ dnn/src/fallback/elemwise_op.h | 1432 ------------------ dnn/src/fallback/general_intrinsic/gi_common.h | 13 + dnn/src/fallback/general_intrinsic/gi_float.h | 52 +- dnn/src/fallback/general_intrinsic/gi_int.h | 63 +- dnn/test/fallback/elemwise.cpp | 1 - 80 files changed, 1810 insertions(+), 3211 deletions(-) create mode 100644 dnn/src/arm_common/elemwise_helper/elemwise_op.h delete mode 100644 dnn/src/arm_common/elemwise_helper/kimpl/pow.h delete mode 100644 dnn/src/arm_common/elemwise_op.h create mode 100644 dnn/src/fallback/elemwise_helper/elemwise_op.h delete mode 100644 dnn/src/fallback/elemwise_op.h diff --git a/dnn/src/aarch64/conv_bias/int8/algos.cpp b/dnn/src/aarch64/conv_bias/int8/algos.cpp index 79b6b0d3..b5e97c01 100644 --- a/dnn/src/aarch64/conv_bias/int8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/int8/algos.cpp @@ -12,7 +12,7 @@ #include "src/aarch64/conv_bias/int8/algos.h" #include "src/aarch64/conv_bias/int8/strategy.h" #include "src/arm_common/convolution/img2col_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" #include "src/fallback/conv_bias/common.h" #include "src/fallback/matrix_mul/gemm_impl.h" diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.cpp b/dnn/src/aarch64/conv_bias/quint8/algos.cpp index 11596fdc..809cd3ef 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/quint8/algos.cpp @@ -14,7 +14,7 @@ #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" #include "src/arm_common/convolution/img2col_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" #include "src/fallback/conv_bias/common.h" #include "src/fallback/matrix_mul/gemm_impl.h" diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp index aa9a072d..1e1a797f 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp @@ -11,7 +11,7 @@ */ #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp index f532a683..cfa427b8 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp +++ 
b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/f16/algos.h" #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp index 1dd0934d..b0ddafe0 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp index 12d10f82..42de168e 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp @@ -15,7 +15,7 @@ #include "src/arm_common/conv_bias/f16/algos.h" #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp index 7c1f75f2..b00d4b18 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp @@ -13,7 +13,7 @@ #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp index 5d9f1328..2da295c8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp @@ -11,7 +11,7 @@ */ #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp index 48b65f8a..46700ba4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp @@ -11,7 +11,7 @@ */ #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp 
b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp index 6f717cba..60f2fdb2 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp index 4acda6fc..e7d403e2 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp @@ -13,7 +13,7 @@ #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index 3915caea..37b893a2 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -14,7 +14,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index cbbf047a..e52654a0 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -14,7 +14,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index 1869f2ff..07372da1 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -14,7 +14,7 @@ #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include 
"src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index f454da13..508a0a21 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -15,7 +15,7 @@ #include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp index fbfa91fa..9efeff10 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp @@ -15,7 +15,7 @@ #include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "src/common/opr_delegate.h" diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h index dc26275c..51477e02 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h @@ -13,7 +13,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp index a74ef427..06146081 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -17,7 +17,7 @@ #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" #include "src/arm_common/conv_bias/int8/stride2.h" #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/fallback/conv_bias/common.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp index 3c98e3a3..7521c6de 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp @@ -11,7 +11,7 @@ */ #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp index 85006e64..e862a3e8 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" #include "megdnn/oprs.h" #include 
"src/arm_common/conv_bias/int8/channel_wise_kernel.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct.cpp b/dnn/src/arm_common/conv_bias/int8/direct.cpp index f62610fa..dc52ac46 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct.cpp @@ -10,7 +10,7 @@ */ #include "src/arm_common/conv_bias/int8/direct.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp index d9b7421b..6d1d3ed3 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp @@ -11,7 +11,7 @@ #include "src/arm_common/conv_bias/int8/direct_dotprod.h" #if MGB_ENABLE_DOT -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp index c124ee36..558f5db4 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -14,7 +14,7 @@ #if MGB_ENABLE_DOT #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h index 1050da31..f7a2a4ed 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h @@ -14,7 +14,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #if MGB_ENABLE_DOT -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/intrinsic_helper.h" #include "src/arm_common/neon_struct.h" #include "src/common/unroll_macro.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp index 43e4839b..937c5cb9 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp @@ -13,7 +13,7 @@ #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp index 
b2a3cff5..a55f2e5d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp @@ -14,7 +14,7 @@ #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp index 27d3cae4..e7dcf989 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp @@ -14,7 +14,7 @@ #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h index 3e5cea4d..176e5e61 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h @@ -12,7 +12,7 @@ #pragma once #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp index fb82f210..82c79497 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp @@ -14,7 +14,7 @@ #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "src/common/opr_delegate.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h index 403968b0..fe3104b1 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h @@ -13,7 +13,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 8224d006..5f7ff1c7 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -15,7 +15,7 @@ #include 
"src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h index ba9394ec..fbbee3e0 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h @@ -14,7 +14,7 @@ #include "src/arm_common/conv_bias/intrinsic_helper.h" #if MGB_ENABLE_DOT -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8/stride1.cpp b/dnn/src/arm_common/conv_bias/int8/stride1.cpp index ce61f4fe..b688e5e5 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride1.cpp @@ -14,7 +14,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp index 6916dc64..4acad913 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp @@ -14,7 +14,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/int8/stride2.cpp b/dnn/src/arm_common/conv_bias/int8/stride2.cpp index a91e52aa..017ccf38 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride2.cpp @@ -14,7 +14,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp index f51d2c8f..33964cfa 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp @@ -14,7 +14,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp index 65c986e9..c51bb933 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp +++ 
b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp @@ -11,7 +11,7 @@ */ #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp index bf3279aa..7b8a3818 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp @@ -15,7 +15,7 @@ #include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "src/common/opr_delegate.h" diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h index d00cba5b..283f4442 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h @@ -13,7 +13,7 @@ #include "megdnn/arch.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h index 16ad7de7..1b0d1504 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h @@ -16,7 +16,7 @@ #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/conv_bias/matmul_postprocess.h b/dnn/src/arm_common/conv_bias/matmul_postprocess.h index 00c75f8b..15cb0bc6 100644 --- a/dnn/src/arm_common/conv_bias/matmul_postprocess.h +++ b/dnn/src/arm_common/conv_bias/matmul_postprocess.h @@ -12,7 +12,7 @@ #include "megdnn/dtype.h" #include "megdnn/oprs.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index f2999c98..7e845d2e 100644 --- a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -13,8 +13,8 @@ #pragma once #include "megdnn/basic_types.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/elemwise_helper/kimpl/op_base.h" -#include 
"src/arm_common/elemwise_op.h" #include "src/fallback/conv_bias/opr_impl.h" #include "midout.h" @@ -44,29 +44,29 @@ namespace { break; #define FOR_NONLINEAR_UNARY(_op) \ - megdnn::arm_common::OpCallerUnary<_op, megdnn::VEC>::run( \ + megdnn::elemwise::OpCallerUnary<_op, megdnn::elemwise::VEC>::run( \ static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ bias_type, dst_type, N* OC* OH* OW* pack_oc_size); -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_BCAST101>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ OH* OW); #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ - OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_VEC>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ + OC, OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY(_op) \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_mode) \ @@ -167,33 +167,35 @@ struct PostProcess { #undef FOR_BIAS #undef HANDLE_IDENTITY -#define FOR_NONLINEAR_UNARY(_op) \ - megdnn::arm_common::OpCallerUnary<_op, megdnn::VEC>::run( \ - static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ - bias_type, dst_type, N* OC* OH* OW* pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::elemwise::OpCallerUnary<_op, megdnn::elemwise::VEC>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(dst_ptr), bias_type, dst_type, \ + N* OC* OH* OW* pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::elemwise::OpCallerBinary< \ + _op, megdnn::elemwise::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ N, OC, OH* OW); -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, 
bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::elemwise::OpCallerBinary< \ + _op, megdnn::elemwise::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ + megdnn::elemwise::OpCallerBinary< \ + _op, megdnn::elemwise::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); #define HANDLE_IDENTITY(_caller, _op) \ case megdnn::NonlineMode::IDENTITY: \ @@ -267,25 +269,25 @@ struct PostProcess { #undef FOR_NONLINEAR #undef FOR_BIAS -#define FOR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ +#define FOR_BINARY_BROADCAST(_op) \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_BCAST101>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ OH* OW); #define FOR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ - OH* OW, pack_oc_size); - -#define FOR_BINARY(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_VEC>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ + OC, OH* OW, pack_oc_size); + +#define FOR_BINARY(_op) \ + megdnn::elemwise::OpCallerBinary<_op, megdnn::elemwise::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_bias_mode, OH, OW) \ diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.cpp b/dnn/src/arm_common/conv_bias/quint8/algos.cpp index 9c7992e0..49462c9d 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/algos.cpp @@ -15,7 +15,7 @@ #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" #include "src/arm_common/conv_bias/quint8/stride2.h" #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/fallback/conv_bias/common.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/quint8/direct.cpp b/dnn/src/arm_common/conv_bias/quint8/direct.cpp index bd337e83..7a07e946 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/direct.cpp @@ -10,7 +10,7 @@ */ #include "src/arm_common/conv_bias/quint8/direct.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp 
b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp index 925b1eea..085c81d3 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp @@ -11,7 +11,7 @@ #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #if MGB_ENABLE_DOT -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp index 37099e79..f08f2fbf 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/quint8/stride1.h" #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp index 30dbf6a0..a54a1b0f 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp @@ -12,7 +12,7 @@ #if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp index f3ec300f..8c3fc55e 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp @@ -12,7 +12,7 @@ #include "src/arm_common/conv_bias/quint8/stride2.h" #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp index 839b0b30..3f60422b 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp @@ -12,7 +12,7 @@ #if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/opr_delegate.h" using namespace megdnn; diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp index 515d1ae0..e74d4e0d 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.cpp +++ b/dnn/src/arm_common/elemwise/binary/algo.cpp @@ -10,7 +10,7 @@ * implied. 
*/ #include "src/arm_common/elemwise/binary/algo.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -20,6 +20,7 @@ MIDOUT_DECL(megdnn_arm_common_elemwise_binary) using namespace megdnn; +using namespace elemwise; using namespace arm_common; namespace { @@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ - DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ DISPATCH_BINARY( \ @@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ - DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ DISPATCH_BINARY( \ FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index c4800d49..4a8ca5ae 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -13,7 +13,7 @@ #include "src/arm_common/elemwise/binary/algo.h" #include "src/arm_common/elemwise/ternary/algo.h" #include "src/arm_common/elemwise/unary/algo.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/metahelper.h" #include "src/common/utils.h" diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index 7d0cc9cf..5cfb8363 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -12,7 +12,7 @@ #pragma once #include "src/fallback/elemwise/opr_impl.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index db658be1..5c8e675a 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -10,7 +10,7 @@ * implied. */ #include "src/arm_common/elemwise/ternary/algo.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -20,6 +20,7 @@ MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) using namespace megdnn; +using namespace elemwise; using namespace arm_common; #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ diff --git a/dnn/src/arm_common/elemwise/unary/algo.cpp b/dnn/src/arm_common/elemwise/unary/algo.cpp index f4d48f93..16bbac60 100644 --- a/dnn/src/arm_common/elemwise/unary/algo.cpp +++ b/dnn/src/arm_common/elemwise/unary/algo.cpp @@ -10,7 +10,7 @@ * implied. 
*/ #include "src/arm_common/elemwise/unary/algo.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -20,6 +20,7 @@ MIDOUT_DECL(megdnn_arm_common_elemwise_unary) using namespace megdnn; +using namespace elemwise; using namespace arm_common; bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { diff --git a/dnn/src/arm_common/elemwise_helper/elemwise_op.h b/dnn/src/arm_common/elemwise_helper/elemwise_op.h new file mode 100644 index 00000000..b62a2987 --- /dev/null +++ b/dnn/src/arm_common/elemwise_helper/elemwise_op.h @@ -0,0 +1,151 @@ +/** + * \file dnn/src/arm_common/elemwise_helper/elemwise_op.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "src/arm_common/elemwise_helper/op_binary.h" +#include "src/arm_common/elemwise_helper/op_ternary.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" + +namespace megdnn { +namespace elemwise { + +using BcastType = megdnn::elemwise::BcastType; + +///////////////////////////////// ParamElemVistor /////////////////////////// + +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitor<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix(reinterpret_cast(src)); \ + } \ + }; \ + template <> \ + struct ParamElemVisitorDup<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vdupq_n_##_fun_suffix(*reinterpret_cast(src)); \ + } \ + } +cb(dt_quint8, uint8_t, uint8x16_t, u8); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, __fp16, float16x8_t, f16); +#endif +cb(dt_int16, int16_t, int16x8_t, s16); +#undef cb + +template +struct ParamElemVisitorBcast101x4; +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ + reinterpret_cast(src))); \ + } \ + } + +cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); +cb(dt_int16, int64_t, int16x8_t, s16, s64); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, uint64_t, float16x8_t, f16, u64); +#endif +#undef cb + +template +struct ParamElemVisitorBcast101x8; +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x8<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix(reinterpret_cast(src)); \ + } \ + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, __fp16, float16x8_t, f16); +#endif +#undef cb + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct OpCallerBinaryBcast101xXVec<__fp16, 8> { + using src_ctype = __fp16; + + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitorBcast101x8 vis0; + ParamElemVisitor vis1; + OpCallerBinaryBcast101xDVec::run( + src0, src1, dst, op, vis0, vis1, batch, 
nr_channel_blocks, + channel_stride); + } +}; + +template <> +struct OpCallerBinaryVecBcast101xX<__fp16, 8> { + using src_ctype = __fp16; + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x8 vis1; + OpCallerBinaryVecBcast101xD::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +template <> +struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { + using src_ctype = __fp16; + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitorBcast101x8 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x8 vis2; + OpCallerTernaryBcast101xDVecBcast101xD::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; + +template <> +struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { + using src_ctype = __fp16; + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x8 vis1; + ParamElemVisitor vis2; + OpCallerTernaryVecBcast101xDVec::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; +#endif + +} // namespace elemwise +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/pow.h b/dnn/src/arm_common/elemwise_helper/kimpl/pow.h deleted file mode 100644 index a8154598..00000000 --- a/dnn/src/arm_common/elemwise_helper/kimpl/pow.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- */ -#pragma once - -#include "src/arm_common/elemwise_helper/kimpl/op_base.h" - -namespace megdnn { -namespace arm_common { - -// when __fp16 is avaliable POW is very slow, so add there -/////////////////////// POW float only //////////////////////////// -template -struct PowOp : BinaryOpBase { - using BinaryOpBase::BinaryOpBase; - constexpr static size_t SIMD_WIDTH = 1; - void operator()( - const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { - *dst = operator()(src0, src1); - } - dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { - return powf(src0, src1); - } -}; - -} // namespace arm_common -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/op_binary.h b/dnn/src/arm_common/elemwise_helper/op_binary.h index ee8aeff3..3a281503 100644 --- a/dnn/src/arm_common/elemwise_helper/op_binary.h +++ b/dnn/src/arm_common/elemwise_helper/op_binary.h @@ -18,7 +18,6 @@ #include "src/arm_common/elemwise_helper/kimpl/max.h" #include "src/arm_common/elemwise_helper/kimpl/min.h" #include "src/arm_common/elemwise_helper/kimpl/mul.h" -#include "src/arm_common/elemwise_helper/kimpl/pow.h" #include "src/arm_common/elemwise_helper/kimpl/rmulh.h" #include "src/arm_common/elemwise_helper/kimpl/sub.h" #include "src/arm_common/elemwise_helper/kimpl/true_div.h" diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp index df6cbd76..d56785dc 100644 --- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp @@ -15,7 +15,7 @@ #include "src/common/elemwise_multi_type/kern_defs.cuh" #include "src/naive/handle.h" -#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" namespace { @@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k( } // namespace +using namespace elemwise; + namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h deleted file mode 100644 index 92ebcc29..00000000 --- a/dnn/src/arm_common/elemwise_op.h +++ /dev/null @@ -1,1537 +0,0 @@ -/** - * \file dnn/src/arm_common/elemwise_op.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#pragma once - -#include "src/arm_common/elemwise_helper/op_binary.h" -#include "src/arm_common/elemwise_helper/op_ternary.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/fallback/elemwise_helper/op_common.h" - -namespace megdnn { -namespace arm_common { - -using BcastType = megdnn::BcastType; - -///////////////////////////////// ParamElemVistor /////////////////////////// -template -struct ParamElemVisitor; - -//! 
visitor single elemwise, and dup to vector -template -struct ParamElemVisitorDup; - -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitor<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix(reinterpret_cast(src)); \ - } \ - }; \ - template <> \ - struct ParamElemVisitorDup<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vdupq_n_##_fun_suffix(*reinterpret_cast(src)); \ - } \ - } -cb(dt_qint32, int32_t, int32x4_t, s32); -cb(dt_qint8, int8_t, int8x16_t, s8); -cb(dt_quint8, uint8_t, uint8x16_t, u8); - -cb(dt_float32, float32_t, float32x4_t, f32); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -cb(__fp16, __fp16, float16x8_t, f16); -#endif -cb(dt_int32, int32_t, int32x4_t, s32); -cb(dt_int16, int16_t, int16x8_t, s16); -cb(dt_int8, int8_t, int8x16_t, s8); -#undef cb - -template -struct ParamElemVisitorBcast101x4; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ - reinterpret_cast(src))); \ - } \ - } - -cb(dt_qint8, int32_t, int8x16_t, s8, s32); -cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); -cb(dt_int8, int32_t, int8x16_t, s8, s32); -cb(dt_int16, int64_t, int16x8_t, s16, s64); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -cb(__fp16, uint64_t, float16x8_t, f16, u64); -#endif -#undef cb -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix(reinterpret_cast(src)); \ - } \ - } - -cb(dt_qint32, int32_t, int32x4_t, s32); -cb(dt_float32, float32_t, float32x4_t, f32); -cb(dt_int32, int32_t, int32x4_t, s32); -#undef cb - -template -struct ParamElemVisitorBcast101x8; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x8<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix(reinterpret_cast(src)); \ - } \ - } -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -cb(__fp16, __fp16, float16x8_t, f16); -#endif -#undef cb - -///////////////////////////////// OpCaller ///////////////////////////// -template -struct OpCallerUnary; - -template -struct OpCallerUnary { - static void run( - const typename Op::src_ctype* src, typename Op::dst_ctype* dst, - DType src_dtype, DType dst_dtype, size_t nr_elems) { - Op op(src_dtype, dst_dtype); - ParamElemVisitor vis; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); - src += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src, dst); - src++; - dst++; - } - } -}; - -template -struct OpCallerBinary; - -///////////////////////// Pow //////////////////////////////// -template -struct OpCallerBinary, VEC_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot 
select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, dst); - src0++; - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary, VEC_SCALAR> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, dst); - src0++; - dst++; - } - } -}; - -template -struct OpCallerBinary, VEC_BCAST101> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr = src1; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - dst++; - } - src1_ptr++; - } - } - } -}; - -template -struct OpCallerBinary, VEC_BCASTX0X> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_ptr = src1_ptr_base; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, VEC_BCAST111C> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - const typename Op::src_ctype* src1_ptr = src1; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, BCAST111C_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - const typename Op::src_ctype* src0_ptr = 
src0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, SCALAR_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(src0, *src1, dst); - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary, BCAST101_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src1++; - dst++; - } - src0_ptr++; - } - } - } -}; - -template -struct OpCallerBinary, BCASTX0X_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr_base = src0 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src0_ptr = src0_ptr_base; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, dst); - src0++; - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - 
ParamElemVisitorDup vis1; - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr = src1; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_neon = vis1(src1_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{src1_neon, src1_neon}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - dst++; - } - src1_ptr++; - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_ptr = src1_ptr_base; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_neon0 = vis(src0); - auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); - auto src1_neon0 = vis(src1_ptr); - auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); - op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t rest = channel_stride; - const typename Op::src_ctype* src1_ptr = src1; - while (rest >= Op::SIMD_WIDTH * 2) { - auto src0_neon0 = vis(src0); - auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); - auto src1_neon0 = vis(src1_ptr); - auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); - dst += Op::SIMD_WIDTH * 2; - rest -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - while (rest > 0) { - op(*src0, *src1_ptr, dst); - dst++; - src0++; - src1_ptr++; - rest--; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t rest = channel_stride; - const typename Op::src_ctype* 
src0_ptr = src0; - while (rest >= Op::SIMD_WIDTH * 2) { - auto src0_neon0 = vis(src0_ptr); - auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); - auto src1_neon0 = vis(src1); - auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); - dst += Op::SIMD_WIDTH * 2; - rest -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - while (rest > 0) { - op(*src0_ptr, *src1, dst); - dst++; - src0_ptr++; - src1++; - rest--; - } - } - } - } -}; - -template -struct OpCallerBinary, BCAST101xX_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xDVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto channel_block_vec = vis0(src0_block_ptr); - size_t img_index = 0; - auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src1_offset <= channel_stride; - img_index += 2 * src1_offset) { - op({{channel_block_vec, channel_block_vec}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; - OpCallerBinaryBcast101xDVec::run( - src0, src1, dst, op, vis0, vis1, 
batch, nr_channel_blocks, - channel_stride); - } -}; - -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -struct OpCallerBinaryBcast101xXVec<__fp16, 8> { - using src_ctype = __fp16; - - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitorBcast101x8 vis0; - ParamElemVisitor vis1; - OpCallerBinaryBcast101xDVec::run( - src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, - channel_stride); - } -}; -#endif - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } else { - OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } - } -}; - -template -struct OpCallerBinary, VEC_BCAST101xX> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0), *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xD { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - auto channel_block_vec = vis1(src1_block_ptr); - size_t img_index = 0; - auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src0_offset <= channel_stride; - img_index += 2 * src0_offset) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{channel_block_vec, channel_block_vec}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all 
elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; - OpCallerBinaryVecBcast101xD::run( - src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, - channel_stride); - } -}; - -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -struct OpCallerBinaryVecBcast101xX<__fp16, 8> { - using src_ctype = __fp16; - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x8 vis1; - OpCallerBinaryVecBcast101xD::run( - src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, - channel_stride); - } -}; -#endif - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } else { - OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - auto vis1_neon = vis1(&src1); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, - dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, dst); - src0++; - dst++; - } - } -}; - -//! 
this only for nonswap op, like SUB and DIV -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitorDup vis0; - ParamElemVisitor vis1; - auto vis0_neon = vis0(&src0); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0_neon, vis0_neon}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(src0, *src1, dst); - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitorDup vis0; - ParamElemVisitor vis1; - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t c = 0; c < channel; c++) { - auto vis0_neon = vis0(src0_ptr); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0_neon, vis0_neon}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src1++; - dst++; - } - src0_ptr++; - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - auto src0_ptr_base = src0 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - auto src0_ptr = src0_ptr_base; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_neon0 = vis(src0_ptr); - auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); - auto src1_neon0 = vis(src1); - auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); - op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerTernary; - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitor vis2; - size_t i = 0; - for (; 
i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, *src2, dst); - src0++; - src1++; - src2++; - dst++; - } - } -}; - -//! src0: vector, src1: vector, src2: scalar -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitorDup vis2; - auto vis2_neon = vis2(&src2); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_neon, vis2_neon}}, - dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, src2, dst); - src0++; - src1++; - dst++; - } - } -}; - -//! src0: 1C11, src1: vector, src2: 1C11 -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, size_t channel_stride, - size_t batch_offset) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis1; - ParamElemVisitorDup vis0; - ParamElemVisitorDup vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - auto b_offset = batch_offset; - for (size_t channel = 0; channel < channel_size; channel++) { - size_t i = 0; - auto src0_neon = vis0(src0_ptr); - auto src2_neon = vis2(src2_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{src0_neon, src0_neon}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{src2_neon, src2_neon}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - b_offset -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, *src2_ptr, dst); - src1++; - dst++; - b_offset--; - } - src0_ptr++; - src2_ptr++; - } - src1 += b_offset; - dst += b_offset; - } - } -}; - -//! 
src0: 111C, src1: vector, src2: 111C, src1 may not be contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - size_t src1_offset, const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, - size_t channel_stride, size_t batch_offset) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t batch = 0; batch < batch_size; batch++) { - auto b_offset = batch_offset; - for (size_t channel = 0; channel < channel_size; channel++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_neon0 = vis(src0_ptr); - auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); - auto src1_neon0 = vis(src1); - auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); - auto src2_neon0 = vis(src2_ptr); - auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH); - op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, - {{src2_neon0, src2_neon1}}, dst); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - src2_ptr += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - b_offset -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, *src2_ptr, dst); - src0_ptr++; - src1++; - src2_ptr++; - dst++; - b_offset--; - } - src1 += src1_offset; - } - src1 += b_offset; - dst += b_offset; - } - } -}; - -template -struct OpCallerTernaryBcast101xXVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto src2_block_ptr = src2_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, - *(src2_block_ptr + c_iter), dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerTernaryBcast101xDVecBcast101xD { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto src2_block_ptr = src2_ptr + cb * channel_block_dim; - auto channel_block_vec0 = vis0(src0_block_ptr); - auto channel_block_vec2 = vis2(src2_block_ptr); - size_t img_index = 0; - auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src1_offset <= channel_stride; - img_index += 2 * src1_offset) { - op({{channel_block_vec0, channel_block_vec0}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{channel_block_vec2, channel_block_vec2}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } 
- // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, - *(src2_block_ptr + c_iter), dst); - src1++; - dst++; - } - } - } - } - } -}; - -//! src0: CHW44, src1: vector, src2: CHW44 -template -struct OpCallerTernaryBcast101xXVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; - ParamElemVisitorBcast101x4 vis2; - OpCallerTernaryBcast101xDVecBcast101xD::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; - -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { - using src_ctype = __fp16; - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitorBcast101x8 vis0; - ParamElemVisitor vis1; - ParamElemVisitorBcast101x8 vis2; - OpCallerTernaryBcast101xDVecBcast101xD::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; -#endif - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerTernaryBcast101xXVecBcast101xX::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } else { - OpCallerTernaryBcast101xXVecBcast101xX::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } - } -}; - -//! 
src1: 1C11, src0 and src2 are contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitor vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - auto src1_ptr = src1; - for (size_t channel = 0; channel < channel_size; channel++) { - size_t i = 0; - auto src1_neon = vis1(src1_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{src1_neon, src1_neon}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, *src2, dst); - src0++; - src2++; - dst++; - } - src1_ptr++; - } - } - } -}; - -//! src1: 111C, src0 and src2 may not be contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, size_t src0_offset, - const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, - size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, - size_t channel_size, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitor vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - for (size_t channel = 0; channel < channel_size; channel++) { - auto src1_ptr = src1; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, *src2, dst); - src0++; - src1_ptr++; - src2++; - dst++; - } - src0 += src0_offset; - src2 += src2_offset; - } - } - } -}; - -template -struct OpCallerTernaryVecBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), *src2, dst); - src0++; - src2++; - dst++; - } - } - } - } - } -}; - -//! 
src1: CHW44, src0 and src2 are contig -template -struct OpCallerTernaryVecBcast101xDVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - auto channel_block_vec = vis1(src1_block_ptr); - size_t img_index = 0; - auto offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * offset <= channel_stride; - img_index += 2 * offset) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{channel_block_vec, channel_block_vec}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), *src2, dst); - src0++; - src2++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerTernaryVecBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; - ParamElemVisitor vis2; - OpCallerTernaryVecBcast101xDVec::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; - -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { - using src_ctype = __fp16; - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x8 vis1; - ParamElemVisitor vis2; - OpCallerTernaryVecBcast101xDVec::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; -#endif - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerTernaryVecBcast101xXVec::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } else { - OpCallerTernaryVecBcast101xXVec::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } - } -}; - -//! 
src1: scalar, src0 and src2 has the same shape -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitor vis2; - auto vis1_neon = vis1(&src1); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, *src2, dst); - src0++; - src2++; - dst++; - } - } -}; - -//! src1, src2: scalar, src0 is vector -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - const typename Op::src_ctype src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitorDup vis2; - auto vis1_neon = vis1(&src1); - auto vis2_neon = vis2(&src2); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, - {{vis2_neon, vis2_neon}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, src2, dst); - src0++; - dst++; - } - } -}; - -} // namespace arm_common -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp b/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp index 341b3809..4d050b14 100644 --- a/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp +++ b/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp @@ -2,7 +2,7 @@ * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp */ #include "src/fallback/elemwise/gi_impl/binary/algo.h" -#include "src/fallback/elemwise_op.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -12,6 +12,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_binary) using namespace megdnn; +using namespace elemwise; using namespace fallback; namespace { diff --git a/dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp b/dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp index b4e383c2..d51f3b98 100644 --- a/dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp +++ b/dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp @@ -3,7 +3,7 @@ */ #include "src/fallback/elemwise/gi_impl/ternary/algo.h" -#include "src/fallback/elemwise_op.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -13,6 +13,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_ternary) using namespace megdnn; +using namespace elemwise; using namespace fallback; #define DISPATCH_MODE_FLOAT(_case, _type, 
_type_midout_id) \
diff --git a/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp b/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
index ffbac861..fb888d8a 100644
--- a/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
+++ b/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
@@ -2,7 +2,7 @@
  * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
  */
 #include "src/fallback/elemwise/gi_impl/unary/algo.h"
-#include "src/fallback/elemwise_op.h"
+#include "src/fallback/elemwise_helper/elemwise_op.h"
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
@@ -12,6 +12,7 @@
 MIDOUT_DECL(megdnn_fallback_elemwise_unary)
 
 using namespace megdnn;
+using namespace elemwise;
 using namespace fallback;
 
 bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {
diff --git a/dnn/src/fallback/elemwise/opr_impl.cpp b/dnn/src/fallback/elemwise/opr_impl.cpp
index 98e22374..68eea732 100644
--- a/dnn/src/fallback/elemwise/opr_impl.cpp
+++ b/dnn/src/fallback/elemwise/opr_impl.cpp
@@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
 MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)
 
 using namespace megdnn;
+using namespace elemwise;
 using namespace fallback;
 
 void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
diff --git a/dnn/src/fallback/elemwise/opr_impl.h b/dnn/src/fallback/elemwise/opr_impl.h
index c7285ecb..ea073f0e 100644
--- a/dnn/src/fallback/elemwise/opr_impl.h
+++ b/dnn/src/fallback/elemwise/opr_impl.h
@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #pragma once
-#include "src/fallback/elemwise_op.h"
+#include "src/fallback/elemwise_helper/elemwise_op.h"
 #include "src/naive/elemwise/opr_impl.h"
 
 namespace megdnn {
@@ -60,7 +60,7 @@ private:
 public:
     class AlgoBase;
     struct KernParam {
-        BcastType broad_cast_type;
+        elemwise::BcastType broad_cast_type;
         Mode mode;
         const TensorND* m_dst;
         Handle* handle;
diff --git a/dnn/src/fallback/elemwise_helper/elemwise_op.h b/dnn/src/fallback/elemwise_helper/elemwise_op.h
new file mode 100644
index 00000000..f7b935cf
--- /dev/null
+++ b/dnn/src/fallback/elemwise_helper/elemwise_op.h
@@ -0,0 +1,72 @@
+/**
+ * \file dnn/src/fallback/elemwise_helper/elemwise_op.h
+ */
+
+#pragma once
+
+#include "src/fallback/elemwise_helper/op_binary.h"
+#include "src/fallback/elemwise_helper/op_common.h"
+#include "src/fallback/elemwise_helper/op_ternary.h"
+#include "src/fallback/elemwise_helper/op_unary.h"
+
+#include "src/fallback/general_intrinsic/gi_float.h"
+#include "src/fallback/general_intrinsic/gi_int.h"
+
+namespace megdnn {
+namespace elemwise {
+
+///////////////////////////////// ParamElemVistor ///////////////////////////
+
+#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix)    \
+    template <>                                              \
+    struct ParamElemVisitor<_ctype> {                        \
+        _simd_type operator()(const _ctype* src) const {     \
+            return GiLoad##_fun_suffix(src);                 \
+        }                                                    \
+    };                                                       \
+    template <>                                              \
+    struct ParamElemVisitorDup<_ctype> {                     \
+        _simd_type operator()(const _ctype* src) const {     \
+            return GiBroadcast##_fun_suffix(                 \
+                    *reinterpret_cast(src));                 \
+        }                                                    \
+    }
+cb(dt_qint32, int32_t, GI_INT32_t, Int32);
+cb(dt_qint8, int8_t, GI_INT8_t, Int8);
+
+cb(dt_float32, float, GI_FLOAT32_t, Float32);
+cb(dt_int32, int32_t, GI_INT32_t, Int32);
+cb(dt_int8, int8_t, GI_INT8_t, Int8);
+#undef cb
+
+template
+struct ParamElemVisitorBcast101x4;
+#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix)    \
+    template <>                                                          \
+    struct ParamElemVisitorBcast101x4<_ctype> {                          \
+        _simd_type
operator()(const _ctype* src) const { \ + return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ + *reinterpret_cast(src))); \ + } \ + } + +cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); +cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); +#undef cb +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiLoad##_fun_suffix(src); \ + } \ + } + +cb(dt_qint32, int32_t, GI_INT32_t, Int32); +cb(dt_float32, float, GI_FLOAT32_t, Float32); +cb(dt_int32, int32_t, GI_INT32_t, Int32); +#undef cb + +} // namespace elemwise +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/abs.h b/dnn/src/fallback/elemwise_helper/kimpl/abs.h index 8bec33b0..5fe0263c 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/abs.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/abs.h @@ -58,7 +58,7 @@ struct AbsOpBase : UnaryOpBase { template <> struct AbsOp : AbsOpBase { using AbsOpBase::AbsOpBase; - constexpr static size_t SIMD_WIDTH = 16; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using AbsOpBase::operator(); void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { OPERATOR_UNARY_QINT8_FALLBACK; diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h index 70a7360a..72abc243 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h @@ -87,7 +87,7 @@ template <> struct FuseAddHSwishOp : FuseAddHSwishOpBase { using FuseAddHSwishOpBase::FuseAddHSwishOpBase; using FuseAddHSwishOpBase::operator(); - constexpr static size_t SIMD_WIDTH = 4; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); void operator()( const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, dt_qint8* dst) const { diff --git a/dnn/src/fallback/elemwise_helper/kimpl/hswish.h b/dnn/src/fallback/elemwise_helper/kimpl/hswish.h index 2dfcfc2b..c6e8663f 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/hswish.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/hswish.h @@ -83,7 +83,7 @@ template <> struct HSwishOp : HSwishOpBase { using HSwishOpBase::HSwishOpBase; using HSwishOpBase::operator(); - constexpr static size_t SIMD_WIDTH = 4; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc)); diff --git a/dnn/src/fallback/elemwise_helper/kimpl/max.h b/dnn/src/fallback/elemwise_helper/kimpl/max.h index 31b38641..025c08b6 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/max.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/max.h @@ -77,7 +77,7 @@ struct MaxOpBase : BinaryOpBase { template <> struct MaxOp : MaxOpBase { using MaxOpBase::MaxOpBase; - constexpr static size_t SIMD_WIDTH = 16; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using MaxOpBase::operator(); void operator()( diff --git a/dnn/src/fallback/elemwise_helper/kimpl/min.h b/dnn/src/fallback/elemwise_helper/kimpl/min.h index 598fce33..edac0104 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/min.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/min.h @@ -74,7 +74,7 @@ struct MinOpBase : BinaryOpBase { template <> struct MinOp : MinOpBase { using MinOpBase::MinOpBase; - constexpr static size_t SIMD_WIDTH = 16; + constexpr static size_t 
SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using MinOpBase::operator(); void operator()( diff --git a/dnn/src/fallback/elemwise_helper/kimpl/mul.h b/dnn/src/fallback/elemwise_helper/kimpl/mul.h index 24da646a..dc58f5ac 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/mul.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/mul.h @@ -73,7 +73,7 @@ struct MulOpBase : BinaryOpBase { template <> struct MulOp : MulOpBase { using MulOpBase::MulOpBase; - constexpr static size_t SIMD_WIDTH = 16; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using MulOpBase::operator(); void operator()( diff --git a/dnn/src/fallback/elemwise_helper/kimpl/none.h b/dnn/src/fallback/elemwise_helper/kimpl/none.h index 8ece32a3..9c20e510 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/none.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/none.h @@ -54,8 +54,6 @@ struct NoneOpBase : UnaryOpBase { } }; -#pragma GCC diagnostic ignored "-Waddress-of-packed-member" - template <> struct NoneOp : NoneOpBase { using NoneOpBase::NoneOpBase; @@ -63,11 +61,11 @@ struct NoneOp : NoneOpBase { constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { - GiStoreInt32(reinterpret_cast(dst), vsrc.val[0]); - GiStoreInt32(reinterpret_cast(dst + 16), vsrc.val[1]); + GiStoreInt32(dst, vsrc.val[0]); + GiStoreInt32(dst + 16, vsrc.val[1]); } void operator()(const GI_INT32_t& src, dt_qint8* dst) const { - GiStoreInt32(reinterpret_cast(dst), src); + GiStoreInt32(dst, src); } }; diff --git a/dnn/src/fallback/elemwise_helper/kimpl/relu.h b/dnn/src/fallback/elemwise_helper/kimpl/relu.h index 7b8365a1..6614fe64 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/relu.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/relu.h @@ -112,36 +112,38 @@ struct ReluOp : ReluOpBase, FixupBase : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { - vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + vst1_s8(reinterpret_cast(dst), vget_low_s8(operator()(vsrc))); } - int8x8_t operator()(const int32x4x2_t& vsrc) const { + int8x16_t operator()(const int32x4x2_t& vsrc) const { int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); - return vqmovn_s16(vcombine_s16( + auto tmp = vqmovn_s16(vcombine_s16( vqmovn_s32(vrshlq_s32(vitem0, vshift)), vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + return vcombine_s8(tmp, tmp); } - int8x8_t operator()(const float32x4_t& vsrc) const { + int8x16_t operator()(const float32x4_t& vsrc) const { int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem0 = vrshlq_s32(vitem0, vshift); int16x4_t vitem = vqmovn_s32(vitem0); - return vqmovn_s16(vcombine_s16(vitem, vitem)); + auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem)); + return vcombine_s8(tmp, tmp); } void operator()(const int32x4_t& src, dt_qint8* dst) const { auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); - auto result = QConverter::convert(vitem0); - vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)result, 0); + auto result = QConverter::convert(vitem0); + vst1q_lane_s32(reinterpret_cast(dst), (int32x4_t)result, 0); } void operator()(const float32x4_t& src, dt_qint8* dst) const { auto vitem0 = vmulq_f32(src, 
this->vscale); vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); - auto result = QConverter::convert(vitem0); - vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)result, 0); + auto result = QConverter::convert(vitem0); + vst1q_lane_s32(reinterpret_cast(dst), (int32x4_t)result, 0); } }; diff --git a/dnn/src/fallback/elemwise_helper/kimpl/sub.h b/dnn/src/fallback/elemwise_helper/kimpl/sub.h index 89ba3546..c898225b 100644 --- a/dnn/src/fallback/elemwise_helper/kimpl/sub.h +++ b/dnn/src/fallback/elemwise_helper/kimpl/sub.h @@ -73,7 +73,7 @@ struct SubOpBase : BinaryOpBase { template <> struct SubOp : SubOpBase { using SubOpBase::SubOpBase; - constexpr static size_t SIMD_WIDTH = 16; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); using SubOpBase::operator(); void operator()( diff --git a/dnn/src/fallback/elemwise_helper/op_common.h b/dnn/src/fallback/elemwise_helper/op_common.h index 6b24ad24..11b9d3d4 100644 --- a/dnn/src/fallback/elemwise_helper/op_common.h +++ b/dnn/src/fallback/elemwise_helper/op_common.h @@ -3,7 +3,11 @@ */ #pragma once +#include "megdnn/dtype.h" +#include "src/fallback/elemwise_helper/kimpl/pow.h" + namespace megdnn { +namespace elemwise { /*! * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] @@ -34,6 +38,1372 @@ enum BcastType { UNKNOWN_BCAST_TYPE }; +///////////////////////////////// ParamElemVistor /////////////////////////// +template +struct ParamElemVisitor; + +//! visitor single elemwise, and dup to vector +template +struct ParamElemVisitorDup; + +template +struct ParamElemVisitorBcast101x4; + +///////////////////////////////// OpCaller ///////////////////////////// +template +struct OpCallerUnary; + +template +struct OpCallerUnary { + static void run( + const typename Op::src_ctype* src, typename Op::dst_ctype* dst, + DType src_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src_dtype, dst_dtype); + ParamElemVisitor vis; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); + src += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src, dst); + src++; + dst++; + } + } +}; + +template +struct OpCallerBinary; + +///////////////////////// Pow //////////////////////////////// +template +struct OpCallerBinary, VEC_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, VEC_SCALAR> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, 
dst); + src0++; + dst++; + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, VEC_BCASTX0X> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_ptr = src1_ptr_base; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST111C> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src1_ptr = src1; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, BCAST111C_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src0_ptr = src0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, SCALAR_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may 
cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, BCAST101_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, BCASTX0X_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr_base = src0 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src0_ptr = src0_ptr_base; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_simd = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_simd, src1_simd}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, 
dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_ptr = src1_ptr_base; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0); + auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1_ptr); + auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src1_ptr = src1; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0); + auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1_ptr); + auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0, *src1_ptr, dst); + dst++; + src0++; + src1_ptr++; + rest--; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src0_ptr = src0; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0_ptr, *src1, dst); + dst++; + src0_ptr++; + 
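// ---------------------------------------------------------------------------
// Illustrative note (not part of the patch): the VEC_BCAST101 and BCAST111C
// specializations above pair one operand laid out as a full (N, C, H, W)
// tensor with an operand that is broadcast along all but one axis.  A minimal
// scalar reference of the BCAST101 case; float and addition are assumptions
// made only for this sketch.  src1 holds a single value per channel, and that
// value is reused for every element of the H*W channel_stride:
#include <cstddef>

void add_vec_bcast101_ref(
        const float* src0, const float* src1, float* dst, size_t batch,
        size_t channel, size_t channel_stride) {
    for (size_t b = 0; b < batch; ++b)
        for (size_t c = 0; c < channel; ++c)
            for (size_t i = 0; i < channel_stride; ++i) {
                // src1 is indexed by channel only: it is the broadcast operand.
                *dst++ = *src0++ + src1[c];
            }
}
// ---------------------------------------------------------------------------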
src1++; + rest--; + } + } + } + } +}; + +template +struct OpCallerBinary, BCAST101xX_VEC> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xDVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto channel_block_vec = vis0(src0_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec, channel_block_vec}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + OpCallerBinaryBcast101xDVec::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, op, batch, nr_channel_blocks, 
channel_stride); + } else { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101xX> { + using Op = fallback::PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0), *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xD { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src0_offset <= channel_stride; + img_index += 2 * src0_offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + OpCallerBinaryVecBcast101xD::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if 
(channel_block_dim == 4) { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); + } else { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + auto vis1_simd = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, dst); + src0++; + dst++; + } + } +}; + +//! this only for nonswap op, like SUB and DIV +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + auto vis0_simd = vis0(&src0); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0_simd, vis0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + auto vis0_simd = vis0(src0_ptr); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0_simd, vis0_simd}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr_base = src0 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + auto src0_ptr = src0_ptr_base; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += 
Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerTernary; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, *src2, dst); + src0++; + src1++; + src2++; + dst++; + } + } +}; + +//! src0: vector, src1: vector, src2: scalar +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitorDup vis2; + auto vis2_simd = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_simd, vis2_simd}}, + dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, src2, dst); + src0++; + src1++; + dst++; + } + } +}; + +//! 
src0: 1C11, src1: vector, src2: 1C11 +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride, + size_t batch_offset) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis1; + ParamElemVisitorDup vis0; + ParamElemVisitorDup vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + auto b_offset = batch_offset; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src0_simd = vis0(src0_ptr); + auto src2_simd = vis2(src2_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{src0_simd, src0_simd}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{src2_simd, src2_simd}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src1++; + dst++; + b_offset--; + } + src0_ptr++; + src2_ptr++; + } + src1 += b_offset; + dst += b_offset; + } + } +}; + +//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + size_t src1_offset, const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, + size_t channel_stride, size_t batch_offset) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t batch = 0; batch < batch_size; batch++) { + auto b_offset = batch_offset; + for (size_t channel = 0; channel < channel_size; channel++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + auto src2_simd0 = vis(src2_ptr); + auto src2_simd1 = vis(src2_ptr + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, + {{src2_simd0, src2_simd1}}, dst); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2_ptr += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src0_ptr++; + src1++; + src2_ptr++; + dst++; + b_offset--; + } + src1 += src1_offset; + } + src1 += b_offset; + dst += b_offset; + } + } +}; + +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < 
nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerTernaryBcast101xDVecBcast101xD { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + auto channel_block_vec0 = vis0(src0_block_ptr); + auto channel_block_vec2 = vis2(src2_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec0, channel_block_vec0}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{channel_block_vec2, channel_block_vec2}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + +//! src0: CHW44, src1: vector, src2: CHW44 +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x4 vis2; + OpCallerTernaryBcast101xDVecBcast101xD::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + +//! 
src1: 1C11, src0 and src2 are contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src1_ptr = src1; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src1_simd = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_simd, src1_simd}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src2++; + dst++; + } + src1_ptr++; + } + } + } +}; + +//! src1: 111C, src0 and src2 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, size_t src0_offset, + const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, + size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, + size_t channel_size, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + for (size_t channel = 0; channel < channel_size; channel++) { + auto src1_ptr = src1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src1_ptr++; + src2++; + dst++; + } + src0 += src0_offset; + src2 += src2_offset; + } + } + } +}; + +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + +//! 
src1: CHW44, src0 and src2 are contig +template +struct OpCallerTernaryVecBcast101xDVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * offset <= channel_stride; + img_index += 2 * offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + ParamElemVisitor vis2; + OpCallerTernaryVecBcast101xDVec::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + +//! 
src1: scalar, src0 and src2 has the same shape +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + auto vis1_simd = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, *src2, dst); + src0++; + src2++; + dst++; + } + } +}; + +//! src1, src2: scalar, src0 is vector +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitorDup vis2; + auto vis1_simd = vis1(&src1); + auto vis2_simd = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + {{vis2_simd, vis2_simd}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, src2, dst); + src0++; + dst++; + } + } +}; + +} // namespace elemwise } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_op.h b/dnn/src/fallback/elemwise_op.h deleted file mode 100644 index 2bd961aa..00000000 --- a/dnn/src/fallback/elemwise_op.h +++ /dev/null @@ -1,1432 +0,0 @@ -/** - * \file dnn/src/fallback/elemwise_op.h - */ - -#pragma once - -#include "src/fallback/elemwise_helper/op_binary.h" -#include "src/fallback/elemwise_helper/op_common.h" -#include "src/fallback/elemwise_helper/op_ternary.h" -#include "src/fallback/elemwise_helper/op_unary.h" - -#include "src/fallback/general_intrinsic/gi_float.h" -#include "src/fallback/general_intrinsic/gi_int.h" - -namespace megdnn { -namespace fallback { - -///////////////////////////////// ParamElemVistor /////////////////////////// -template -struct ParamElemVisitor; - -//! 
visitor single elemwise, and dup to vector -template -struct ParamElemVisitorDup; - -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitor<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiLoad##_fun_suffix(reinterpret_cast(src)); \ - } \ - }; \ - template <> \ - struct ParamElemVisitorDup<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiBroadcast##_fun_suffix( \ - *reinterpret_cast(src)); \ - } \ - } -cb(dt_qint32, int32_t, GI_INT32_t, Int32); -cb(dt_qint8, int8_t, GI_INT8_t, Int8); - -cb(dt_float32, float, GI_FLOAT32_t, Float32); -cb(dt_int32, int32_t, GI_INT32_t, Int32); -cb(dt_int8, int8_t, GI_INT8_t, Int8); -#undef cb - -template -struct ParamElemVisitorBcast101x4; -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ - *reinterpret_cast(src))); \ - } \ - } - -cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); -cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); -#undef cb -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiLoad##_fun_suffix(reinterpret_cast(src)); \ - } \ - } - -cb(dt_qint32, int32_t, GI_INT32_t, Int32); -cb(dt_float32, float, GI_FLOAT32_t, Float32); -cb(dt_int32, int32_t, GI_INT32_t, Int32); -#undef cb - -///////////////////////////////// OpCaller ///////////////////////////// -template -struct OpCallerUnary; - -template -struct OpCallerUnary { - static void run( - const typename Op::src_ctype* src, typename Op::dst_ctype* dst, - DType src_dtype, DType dst_dtype, size_t nr_elems) { - Op op(src_dtype, dst_dtype); - ParamElemVisitor vis; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); - src += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src, dst); - src++; - dst++; - } - } -}; - -template -struct OpCallerBinary; - -///////////////////////// Pow //////////////////////////////// -template -struct OpCallerBinary, VEC_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, dst); - src0++; - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary, VEC_SCALAR> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, dst); - src0++; - dst++; 
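// ---------------------------------------------------------------------------
// Illustrative note: for dt_float32 the visitor macros above (the cb(...)
// blocks) expand to a plain SIMD load and a scalar broadcast.  A sketch of
// that expansion with the template parameter lists written out by hand, so
// treat the exact spellings as an assumption rather than the generated code:
template <>
struct ParamElemVisitor<dt_float32> {
    GI_FLOAT32_t operator()(const dt_float32* src) const {
        // load SIMD_WIDTH contiguous floats from src
        return GiLoadFloat32(reinterpret_cast<const float*>(src));
    }
};
template <>
struct ParamElemVisitorDup<dt_float32> {
    GI_FLOAT32_t operator()(const dt_float32* src) const {
        // duplicate one scalar value into every SIMD lane
        return GiBroadcastFloat32(*reinterpret_cast<const float*>(src));
    }
};
// ---------------------------------------------------------------------------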
- } - } -}; - -template -struct OpCallerBinary, VEC_BCAST101> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr = src1; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - dst++; - } - src1_ptr++; - } - } - } -}; - -template -struct OpCallerBinary, VEC_BCASTX0X> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_ptr = src1_ptr_base; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, VEC_BCAST111C> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - const typename Op::src_ctype* src1_ptr = src1; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, BCAST111C_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - const typename Op::src_ctype* src0_ptr = src0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary, SCALAR_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop 
vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(src0, *src1, dst); - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary, BCAST101_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src1++; - dst++; - } - src0_ptr++; - } - } - } -}; - -template -struct OpCallerBinary, BCASTX0X_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr_base = src0 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src0_ptr = src0_ptr_base; -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, dst); - src0++; - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr = src1; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_simd = vis1(src1_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{src1_simd, src1_simd}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - dst++; - } - src1_ptr++; - } - } - } -}; - -template -struct 
OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - size_t i = 0; - auto src1_ptr = src1_ptr_base; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_simd0 = vis(src0); - auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); - auto src1_simd0 = vis(src1_ptr); - auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); - op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, dst); - src0++; - src1_ptr++; - dst++; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t rest = channel_stride; - const typename Op::src_ctype* src1_ptr = src1; - while (rest >= Op::SIMD_WIDTH * 2) { - auto src0_simd0 = vis(src0); - auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); - auto src1_simd0 = vis(src1_ptr); - auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); - dst += Op::SIMD_WIDTH * 2; - rest -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - while (rest > 0) { - op(*src0, *src1_ptr, dst); - dst++; - src0++; - src1_ptr++; - rest--; - } - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - for (size_t c = 0; c < channel; c++) { - size_t rest = channel_stride; - const typename Op::src_ctype* src0_ptr = src0; - while (rest >= Op::SIMD_WIDTH * 2) { - auto src0_simd0 = vis(src0_ptr); - auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); - auto src1_simd0 = vis(src1); - auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); - dst += Op::SIMD_WIDTH * 2; - rest -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - while (rest > 0) { - op(*src0_ptr, *src1, dst); - dst++; - src0_ptr++; - src1++; - rest--; - } - } - } - } -}; - -template -struct OpCallerBinary, 
BCAST101xX_VEC> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xDVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto channel_block_vec = vis0(src0_block_ptr); - size_t img_index = 0; - auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src1_offset <= channel_stride; - img_index += 2 * src1_offset) { - op({{channel_block_vec, channel_block_vec}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; - OpCallerBinaryBcast101xDVec::run( - src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, - channel_stride); - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } else { - OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, 
op, batch, nr_channel_blocks, channel_stride); - } - } -}; - -template -struct OpCallerBinary, VEC_BCAST101xX> { - using Op = PowOp; - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - Op op(src0_dtype, src1_dtype, dst_dtype); - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0), *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xD { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - auto channel_block_vec = vis1(src1_block_ptr); - size_t img_index = 0; - auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src0_offset <= channel_stride; - img_index += 2 * src0_offset) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{channel_block_vec, channel_block_vec}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), dst); - src0++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerBinaryVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; - OpCallerBinaryVecBcast101xD::run( - src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, - channel_stride); - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, 
nr_channel_blocks, channel_stride); - } else { - OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - auto vis1_simd = vis1(&src1); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, - dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, dst); - src0++; - dst++; - } - } -}; - -//! this only for nonswap op, like SUB and DIV -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t nr_elems) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitorDup vis0; - ParamElemVisitor vis1; - auto vis0_simd = vis0(&src0); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0_simd, vis0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(src0, *src1, dst); - src1++; - dst++; - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitorDup vis0; - ParamElemVisitor vis1; - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - for (size_t c = 0; c < channel; c++) { - auto vis0_simd = vis0(src0_ptr); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0_simd, vis0_simd}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src1++; - dst++; - } - src0_ptr++; - } - } - } -}; - -template -struct OpCallerBinary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t b = 0; b < batch; b++) { - auto src0_ptr_base = src0 + b * channel_stride; - for (size_t c = 0; c < channel; c++) { - auto src0_ptr = src0_ptr_base; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_simd0 = vis(src0_ptr); - auto src0_simd1 = vis(src0_ptr + 
Op::SIMD_WIDTH); - auto src1_simd0 = vis(src1); - auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); - op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, dst); - src0_ptr++; - src1++; - dst++; - } - } - } - } -}; - -template -struct OpCallerTernary; - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitor vis2; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, *src2, dst); - src0++; - src1++; - src2++; - dst++; - } - } -}; - -//! src0: vector, src1: vector, src2: scalar -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitorDup vis2; - auto vis2_simd = vis2(&src2); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_simd, vis2_simd}}, - dst); - src0 += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, *src1, src2, dst); - src0++; - src1++; - dst++; - } - } -}; - -//! 
src0: 1C11, src1: vector, src2: 1C11 -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, size_t channel_stride, - size_t batch_offset) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis1; - ParamElemVisitorDup vis0; - ParamElemVisitorDup vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - auto b_offset = batch_offset; - for (size_t channel = 0; channel < channel_size; channel++) { - size_t i = 0; - auto src0_simd = vis0(src0_ptr); - auto src2_simd = vis2(src2_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{src0_simd, src0_simd}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{src2_simd, src2_simd}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - b_offset -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, *src2_ptr, dst); - src1++; - dst++; - b_offset--; - } - src0_ptr++; - src2_ptr++; - } - src1 += b_offset; - dst += b_offset; - } - } -}; - -//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - size_t src1_offset, const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, - DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, - size_t channel_stride, size_t batch_offset) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis; - for (size_t batch = 0; batch < batch_size; batch++) { - auto b_offset = batch_offset; - for (size_t channel = 0; channel < channel_size; channel++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - auto src0_simd0 = vis(src0_ptr); - auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); - auto src1_simd0 = vis(src1); - auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); - auto src2_simd0 = vis(src2_ptr); - auto src2_simd1 = vis(src2_ptr + Op::SIMD_WIDTH); - op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, - {{src2_simd0, src2_simd1}}, dst); - src0_ptr += Op::SIMD_WIDTH * 2; - src1 += Op::SIMD_WIDTH * 2; - src2_ptr += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - b_offset -= Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0_ptr, *src1, *src2_ptr, dst); - src0_ptr++; - src1++; - src2_ptr++; - dst++; - b_offset--; - } - src1 += src1_offset; - } - src1 += b_offset; - dst += b_offset; - } - } -}; - -template -struct OpCallerTernaryBcast101xXVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - for (size_t cb = 0; cb < 
nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto src2_block_ptr = src2_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, - *(src2_block_ptr + c_iter), dst); - src1++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerTernaryBcast101xDVecBcast101xD { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src0_ptr = src0; - auto src2_ptr = src2; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - auto src2_block_ptr = src2_ptr + cb * channel_block_dim; - auto channel_block_vec0 = vis0(src0_block_ptr); - auto channel_block_vec2 = vis2(src2_block_ptr); - size_t img_index = 0; - auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * src1_offset <= channel_stride; - img_index += 2 * src1_offset) { - op({{channel_block_vec0, channel_block_vec0}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{channel_block_vec2, channel_block_vec2}}, dst); - src1 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*(src0_block_ptr + c_iter), *src1, - *(src2_block_ptr + c_iter), dst); - src1++; - dst++; - } - } - } - } - } -}; - -//! src0: CHW44, src1: vector, src2: CHW44 -template -struct OpCallerTernaryBcast101xXVecBcast101xX { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; - ParamElemVisitorBcast101x4 vis2; - OpCallerTernaryBcast101xDVecBcast101xD::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerTernaryBcast101xXVecBcast101xX::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } else { - OpCallerTernaryBcast101xXVecBcast101xX::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } - } -}; - -//! 
src1: 1C11, src0 and src2 are contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitor vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - auto src1_ptr = src1; - for (size_t channel = 0; channel < channel_size; channel++) { - size_t i = 0; - auto src1_simd = vis1(src1_ptr); - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{src1_simd, src1_simd}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, *src2, dst); - src0++; - src2++; - dst++; - } - src1_ptr++; - } - } - } -}; - -//! src1: 111C, src0 and src2 may not be contig -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, size_t src0_offset, - const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, - size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, - size_t channel_size, size_t channel_stride) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitor vis1; - ParamElemVisitor vis2; - for (size_t batch = 0; batch < batch_size; batch++) { - for (size_t channel = 0; channel < channel_size; channel++) { - auto src1_ptr = src1; - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; - i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src1_ptr += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < channel_stride; i++) { - op(*src0, *src1_ptr, *src2, dst); - src0++; - src1_ptr++; - src2++; - dst++; - } - src0 += src0_offset; - src2 += src2_offset; - } - } - } -}; - -template -struct OpCallerTernaryVecBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), *src2, dst); - src0++; - src2++; - dst++; - } - } - } - } - } -}; - -//! 
src1: CHW44, src0 and src2 are contig -template -struct OpCallerTernaryVecBcast101xDVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { - for (size_t b = 0; b < batch; b++) { - auto src1_ptr = src1; - for (size_t cb = 0; cb < nr_channel_blocks; cb++) { - auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - auto channel_block_vec = vis1(src1_block_ptr); - size_t img_index = 0; - auto offset = Op::SIMD_WIDTH / channel_block_dim; - for (; img_index + 2 * offset <= channel_stride; - img_index += 2 * offset) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{channel_block_vec, channel_block_vec}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } - // TODO:all elemwise_multi_type op imp one simd mode - for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { - op(*src0, *(src1_block_ptr + c_iter), *src2, dst); - src0++; - src2++; - dst++; - } - } - } - } - } -}; - -template -struct OpCallerTernaryVecBcast101xXVec { - template - static void run( - const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; - ParamElemVisitor vis2; - OpCallerTernaryVecBcast101xDVec::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, - channel_stride); - } -}; - -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert( - channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); - - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - if (channel_block_dim == 4) { - OpCallerTernaryVecBcast101xXVec::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } else { - OpCallerTernaryVecBcast101xXVec::run( - src0, src1, src2, dst, op, batch, nr_channel_blocks, - channel_stride); - } - } -}; - -//! 
src1: scalar, src0 and src2 has the same shape -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitor vis2; - auto vis1_simd = vis1(&src1); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, - {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); - src0 += Op::SIMD_WIDTH * 2; - src2 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, *src2, dst); - src0++; - src2++; - dst++; - } - } -}; - -//! src1, src2: scalar, src0 is vector -template -struct OpCallerTernary { - static void run( - const typename Op::src_ctype* src0, const typename Op::src_ctype src1, - const typename Op::src_ctype src2, typename Op::dst_ctype* dst, - DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorDup vis1; - ParamElemVisitorDup vis2; - auto vis1_simd = vis1(&src1); - auto vis2_simd = vis2(&src2); - size_t i = 0; - for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, - {{vis2_simd, vis2_simd}}, dst); - src0 += Op::SIMD_WIDTH * 2; - dst += Op::SIMD_WIDTH * 2; - } -#if MEGDNN_FIX_AARCH32_BUG -// FIXME: as llvm may cause cannot select error if enable vectorize -#pragma clang loop vectorize(disable) -#endif - for (; i < nr_elems; i++) { - op(*src0, src1, src2, dst); - src0++; - dst++; - } - } -}; - -} // namespace fallback -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h index 8c3ff8cb..8be13e05 100644 --- a/dnn/src/fallback/general_intrinsic/gi_common.h +++ b/dnn/src/fallback/general_intrinsic/gi_common.h @@ -13,6 +13,7 @@ #include "math.h" #include "stdint.h" +#include "string.h" #if defined(_WIN32) #include @@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); #define Max(a, b) (a) > (b) ? (a) : (b) #define Min(a, b) (a) < (b) ? 
(a) : (b)
+#if defined(GI_NEON_INTRINSICS)
+#if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS)
+#define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a))
+#define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a))
+#define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane))
+#else
+#define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a))
+#define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a))
+#define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane))
+#endif
+#endif
+
 typedef struct {
     GI_INT32_t val[2];
 } GI_INT32_V2_t;
diff --git a/dnn/src/fallback/general_intrinsic/gi_float.h b/dnn/src/fallback/general_intrinsic/gi_float.h
index dc42fb1d..a6b1ac1c 100644
--- a/dnn/src/fallback/general_intrinsic/gi_float.h
+++ b/dnn/src/fallback/general_intrinsic/gi_float.h
@@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castps_si128(In);
 #else
-    return *(GI_INT32_t*)(&In);
+    GI_INT32_t ret;
+    memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
+    return ret;
 #endif
 }
 
@@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castps_si128(In);
 #else
-    return *(GI_UINT32_t*)(&In);
+    GI_UINT32_t ret;
+    memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
+    return ret;
 #endif
 }
 
@@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castsi128_ps(Vector);
 #else
-    return *(GI_FLOAT32_t*)(&Vector);
+    GI_FLOAT32_t ret;
+    memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
+    return ret;
 #endif
 }
 
@@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) {
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_castsi128_ps(Vector);
 #else
-    return *(GI_FLOAT32_t*)(&Vector);
+    GI_FLOAT32_t ret;
+    memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
+    return ret;
 #endif
 }
 
@@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
     float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half);
     return vcvtq_s32_f32(vaddq_f32(Vector, vinc0));
 #endif
-#elif defined(GI_SSE2_INTRINSICS)
+#elif defined(GI_SSE42_INTRINSICS)
     __m128 vfzero = _mm_set1_ps(0.f);
     __m128 vfhalf = _mm_set1_ps(0.5f);
     __m128 vfneg_half = _mm_set1_ps(-0.5f);
@@ -322,11 +330,7 @@ GI_FORCEINLINE
 GI_FLOAT32_t GiMultiplyAddFloat32(
         GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
 #if defined(GI_NEON_INTRINSICS)
-#if defined(__ARM_FEATURE_FMA)
-    return vfmaq_f32(VectorSum, Vector1, Vector2);
-#else
-    return vmlaq_f32(VectorSum, Vector1, Vector2);
-#endif
+    return v_fma_ps_f32(VectorSum, Vector1, Vector2);
 #elif defined(GI_FMA3_INTRINSICS)
     return _mm_fmadd_ps(Vector1, Vector2, VectorSum);
 #elif defined(GI_SSE2_INTRINSICS)
@@ -352,11 +356,7 @@ GI_FORCEINLINE
 GI_FLOAT32_t GiMultiplyAddScalarFloat32(
         GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) {
 #if defined(GI_NEON_INTRINSICS)
-#if defined(__ARM_FEATURE_FMA)
-    return vfmaq_n_f32(VectorSum, Vector, Scalar);
-#else
-    return vfmla_n_f32(VectorSum, Vector, Scalar);
-#endif
+    return v_fma_n_f32(VectorSum, Vector, Scalar);
 #elif defined(GI_SSE2_INTRINSICS)
     return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector);
 #else
@@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32(
 }
 
 #if defined(GI_NEON_INTRINSICS)
-#if defined(__ARM_FEATURE_FMA)
-#define GIMULTIPLYADDLANFLOAT32(i)                                            \
-    GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32(                 \
-            GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
-        return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i);  \
-    }
-GIMULTIPLYADDLANFLOAT32(0)
-GIMULTIPLYADDLANFLOAT32(1)
-#undef GIMULTIPLYADDLANFLOAT32
-#define GIMULTIPLYADDLANFLOAT32(i)                                            \
-    GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32(                 \
-            GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
-        return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
-    }
-GIMULTIPLYADDLANFLOAT32(2)
-GIMULTIPLYADDLANFLOAT32(3)
-#else
 #define GIMULTIPLYADDLANFLOAT32(i)                                            \
     GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32(                 \
             GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
-        return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i);  \
+        return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i);  \
     }
 GIMULTIPLYADDLANFLOAT32(0)
 GIMULTIPLYADDLANFLOAT32(1)
@@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1)
 #define GIMULTIPLYADDLANFLOAT32(i)                                            \
     GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32(                 \
             GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
-        return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
+        return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
     }
 GIMULTIPLYADDLANFLOAT32(2)
 GIMULTIPLYADDLANFLOAT32(3)
-#endif
 #undef GIMULTIPLYADDLANFLOAT32
 
 #elif defined(GI_SSE2_INTRINSICS)
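The gi_float.h hunks above replace the casted-pointer reinterpret in the generic fallback branch (e.g. `return *(GI_INT32_t*)(&In);`) with a memcpy into a local of the destination type. The pointer cast reads one object type through an lvalue of another, which breaks strict-aliasing rules and can be miscompiled at higher optimization levels; memcpy between equally sized objects is well defined and, for a fixed small size, is normally lowered to a plain register move. A minimal standalone sketch of the same idiom on scalar types (the names below are illustrative, not part of the GI API):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch only: reinterpret the bits of a float as int32_t the way the patched
// GiReinterpretAsInt32 fallback does, but on plain scalars.
static int32_t bits_of(float f) {
    static_assert(sizeof(int32_t) == sizeof(float), "sizes must match");
    int32_t ret;
    std::memcpy(&ret, &f, sizeof(ret));  // well-defined type punning
    return ret;
}

int main() {
    std::printf("0x%08x\n", static_cast<unsigned>(bits_of(1.0f)));  // prints 0x3f800000
    return 0;
}

This is also why the gi_common.h hunk adds `#include "string.h"`: the fallback paths now rely on memcpy.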
diff --git a/dnn/src/fallback/general_intrinsic/gi_int.h b/dnn/src/fallback/general_intrinsic/gi_int.h
index 2da8283a..abb4e2b1 100644
--- a/dnn/src/fallback/general_intrinsic/gi_int.h
+++ b/dnn/src/fallback/general_intrinsic/gi_int.h
@@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) {
 }
 
 GI_FORCEINLINE
-GI_INT32_t GiLoadInt32(const int32_t* Buffer) {
+GI_INT32_t GiLoadInt32(const void* Buffer) {
 #if defined(GI_NEON_INTRINSICS)
-    return vld1q_s32(Buffer);
+    return vld1q_s32((int32_t*)Buffer);
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_loadu_si128((const __m128i*)Buffer);
 #else
     GI_INT32_t ret;
+    const int32_t* ptr = (int32_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
-        ret[i] = Buffer[i];
+        ret[i] = ptr[i];
     }
     return ret;
 #endif
 }
 
 GI_FORCEINLINE
-GI_INT8_t GiLoadInt8(const int8_t* Buffer) {
+GI_INT8_t GiLoadInt8(const void* Buffer) {
 #if defined(GI_NEON_INTRINSICS)
-    return vld1q_s8(Buffer);
+    return vld1q_s8((int8_t*)Buffer);
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_loadu_si128((const __m128i*)Buffer);
 #else
     GI_INT8_t ret;
+    const int8_t* ptr = (int8_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
-        ret[i] = Buffer[i];
+        ret[i] = ptr[i];
     }
     return ret;
 #endif
 }
 
 GI_FORCEINLINE
-void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) {
+void GiStoreInt32(void* Buffer, GI_INT32_t Vector) {
 #if defined(GI_NEON_INTRINSICS)
-    vst1q_s32(Buffer, Vector);
+    vst1q_s32((int32_t*)Buffer, Vector);
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storeu_si128((__m128i*)Buffer, Vector);
 #else
+    int32_t* ptr = (int32_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
-        Buffer[i] = Vector[i];
+        ptr[i] = Vector[i];
     }
 #endif
 }
 
 #if defined(GI_NEON_INTRINSICS)
-#define GISTORELANEINT32(i)                                                   \
-    GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
-        vst1q_lane_s32(Buffer, Vector, i);                                    \
+#define GISTORELANEINT32(i)                                                   \
+    GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
+        vst1q_lane_s32((int32_t*)Buffer, Vector, i);                          \
     }
 #elif defined(GI_SSE2_INTRINSICS)
 #define GISTORELANEINT32(i)                                                   \
-    GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
+    GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
         GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector);                          \
         _mm_store_ss(                                                         \
                 (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \
     }
 #else
-#define GISTORELANEINT32(i)                                                   \
-    GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
-        *Buffer = Vector[i];                                                  \
+#define GISTORELANEINT32(i)                                                   \
+    GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
+        *((int32_t*)Buffer) = Vector[i];                                      \
     }
 #endif
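The gi_int.h hunk above changes the load/store helpers to take `void*` (or `const void*`) instead of typed pointers: the NEON and SSE2 branches cast to the pointer type the intrinsic expects, and the scalar fallback introduces a local typed pointer rather than indexing the buffer argument directly, so callers holding differently typed buffers no longer need a cast at every call site. A small sketch of the same shape with hypothetical names (Vec4i, load_s32 and store_s32 are not the GI API):

#include <cstdint>
#include <cstdio>

// Sketch only: a void*-taking load/store pair in the style of the scalar
// fallback branches of GiLoadInt32/GiStoreInt32.
struct Vec4i {
    int32_t val[4];
};

static Vec4i load_s32(const void* buffer) {
    const int32_t* ptr = static_cast<const int32_t*>(buffer);  // recover the element type
    Vec4i ret;
    for (int i = 0; i < 4; i++) {
        ret.val[i] = ptr[i];
    }
    return ret;
}

static void store_s32(void* buffer, const Vec4i& v) {
    int32_t* ptr = static_cast<int32_t*>(buffer);
    for (int i = 0; i < 4; i++) {
        ptr[i] = v.val[i];
    }
}

int main() {
    int32_t src[4] = {1, 2, 3, 4};
    int32_t dst[4] = {0, 0, 0, 0};
    store_s32(dst, load_s32(src));  // round-trip through the two helpers
    std::printf("%d %d %d %d\n", static_cast<int>(dst[0]), static_cast<int>(dst[1]),
                static_cast<int>(dst[2]), static_cast<int>(dst[3]));
    return 0;
}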
@@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) {
 }
 
 GI_FORCEINLINE
-void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) {
+void GiStoreInt16(void* Buffer, GI_INT16_t Vector) {
 #if defined(GI_NEON_INTRINSICS)
-    vst1q_s16(Buffer, Vector);
+    vst1q_s16((int16_t*)Buffer, Vector);
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storeu_si128((__m128i*)Buffer, Vector);
 #else
+    int16_t* ptr = (int16_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) {
-        Buffer[i] = Vector[i];
+        ptr[i] = Vector[i];
     }
 #endif
 }
 
 GI_FORCEINLINE
-void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) {
+void GiStoreInt8(void* Buffer, GI_INT8_t Vector) {
 #if defined(GI_NEON_INTRINSICS)
-    vst1q_s8(Buffer, Vector);
+    vst1q_s8((int8_t*)Buffer, Vector);
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storeu_si128((__m128i*)Buffer, Vector);
 #else
+    int8_t* ptr = (int8_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
-        Buffer[i] = Vector[i];
+        ptr[i] = Vector[i];
     }
 #endif
 }
 
 GI_FORCEINLINE
-void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) {
+void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) {
 #if defined(GI_NEON_INTRINSICS)
-    vst1_s8(Buffer, vget_low_s8(Vector));
+    vst1_s8((int8_t*)Buffer, vget_low_s8(Vector));
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storel_epi64((__m128i*)Buffer, Vector);
 #else
+    int8_t* ptr = (int8_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
-        Buffer[i] = Vector[i];
+        ptr[i] = Vector[i];
     }
 #endif
 }
 
 GI_FORCEINLINE
-void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) {
+void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) {
 #if defined(GI_NEON_INTRINSICS)
-    vst1_s8(Buffer, vget_high_s8(Vector));
+    vst1_s8((int8_t*)Buffer, vget_high_s8(Vector));
 #elif defined(GI_SSE2_INTRINSICS)
     _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector));
 #else
+    int8_t* ptr = (int8_t*)Buffer;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
-        Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
+        ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
     }
 #endif
 }
diff --git a/dnn/test/fallback/elemwise.cpp b/dnn/test/fallback/elemwise.cpp
index ef9487ea..8d58a692 100644
--- a/dnn/test/fallback/elemwise.cpp
+++ b/dnn/test/fallback/elemwise.cpp
@@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) {
     checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
 }
 
-
 TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) {
     using Mode = ElemwiseForward::Param::Mode;
     Checker checker(handle());