GitOrigin-RevId: 87046b8197
release-1.10
@@ -12,7 +12,7 @@ | |||||
#include "src/aarch64/conv_bias/int8/algos.h" | #include "src/aarch64/conv_bias/int8/algos.h" | ||||
#include "src/aarch64/conv_bias/int8/strategy.h" | #include "src/aarch64/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/convolution/img2col_helper.h" | #include "src/arm_common/convolution/img2col_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | ||||
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | ||||
#include "src/arm_common/convolution/img2col_helper.h" | #include "src/arm_common/convolution/img2col_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
@@ -11,7 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/f16/algos.h" | #include "src/arm_common/conv_bias/f16/algos.h" | ||||
#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | ||||
#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/f16/algos.h" | #include "src/arm_common/conv_bias/f16/algos.h" | ||||
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -11,7 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -11,7 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | #include "src/arm_common/conv_bias/fp32/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -17,7 +17,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | ||||
#include "src/arm_common/conv_bias/int8/stride2.h" | #include "src/arm_common/conv_bias/int8/stride2.h" | ||||
#include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -11,7 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" | #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -10,7 +10,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -11,7 +11,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/intrinsic_helper.h" | #include "src/arm_common/intrinsic_helper.h" | ||||
#include "src/arm_common/neon_struct.h" | #include "src/arm_common/neon_struct.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -14,7 +14,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
#include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -11,7 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | #include "src/arm_common/conv_bias/int8x8x16/algos.h" | ||||
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
@@ -13,7 +13,7 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -16,7 +16,7 @@ | |||||
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -13,8 +13,8 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/op_base.h" | #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -44,29 +44,29 @@ namespace { | |||||
break; | break; | ||||
#define FOR_NONLINEAR_UNARY(_op) \ | #define FOR_NONLINEAR_UNARY(_op) \ | ||||
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \ | |||||
megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ | static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ | ||||
bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | ||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
OH* OW); | OH* OW); | ||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | ||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
OC, OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY(_op) \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N* OC* OH* OW* pack_oc_size); | N* OC* OH* OW* pack_oc_size); | ||||
#define FOR_BIAS(_mode) \ | #define FOR_BIAS(_mode) \ | ||||
@@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
#undef FOR_BIAS | #undef FOR_BIAS | ||||
#undef HANDLE_IDENTITY | #undef HANDLE_IDENTITY | ||||
#define FOR_NONLINEAR_UNARY(_op) \ | |||||
megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \ | |||||
static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \ | |||||
bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \ | |||||
run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
#define FOR_NONLINEAR_UNARY(_op) \ | |||||
megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \ | |||||
run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \ | |||||
N* OC* OH* OW* pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
megdnn::elemwise::OpCallerBinary< \ | |||||
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \ | |||||
run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N, OC, OH* OW); | N, OC, OH* OW); | ||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||||
megdnn::arm_common:: \ | |||||
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N, OC, OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||||
megdnn::arm_common:: \ | |||||
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||||
dst_type, N, OC, OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||||
megdnn::elemwise::OpCallerBinary< \ | |||||
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N, OC, OH* OW, pack_oc_size); | |||||
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||||
megdnn::elemwise::OpCallerBinary< \ | |||||
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N, OC, OH* OW, pack_oc_size); | |||||
#define HANDLE_IDENTITY(_caller, _op) \ | #define HANDLE_IDENTITY(_caller, _op) \ | ||||
case megdnn::NonlineMode::IDENTITY: \ | case megdnn::NonlineMode::IDENTITY: \ | ||||
@@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
#undef FOR_NONLINEAR | #undef FOR_NONLINEAR | ||||
#undef FOR_BIAS | #undef FOR_BIAS | ||||
// FOR_BINARY_BROADCAST appears twice in this diff rendering: first calling
// megdnn::arm_common::OpCallerBinary with megdnn::VEC_BCAST101, then calling
// megdnn::elemwise::OpCallerBinary with megdnn::elemwise::VEC_BCAST101. Both
// variants pass the same casted pointers and end with the shared "OH* OW);"
// row, which is an unchanged context line of the diff.
// NOTE(review): +/- markers are lost; presumably the arm_common rows are the
// removed side and the elemwise rows the added side - confirm.
#define FOR_BINARY_BROADCAST(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
#define FOR_BINARY_BROADCAST(_op) \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
OH* OW); | OH* OW); | ||||
// Interleaved diff rows for two macros:
//  * FOR_BINARY_BROADCAST_NCHWXX - OpCallerBinary<_op<ctype>, VEC_BCAST101xX>::run
//    with N, OC, OH * OW and pack_oc_size as trailing arguments;
//  * FOR_BINARY - OpCallerBinary<_op<ctype>, VEC_VEC>::run with the single
//    element count N * OC * OH * OW * pack_oc_size as trailing argument.
// Each macro body is shown once with megdnn::arm_common / megdnn::VEC_* names
// and once with megdnn::elemwise names; the first and last rows are unchanged
// context rows shared by both variants.
// NOTE(review): +/- markers are lost; presumably the arm_common rows are the
// removed side - confirm against the commit.
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ | #define FOR_BINARY_BROADCAST_NCHWXX(_op) \ | ||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
OH* OW, pack_oc_size); | |||||
#define FOR_BINARY(_op) \ | |||||
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
OC, OH* OW, pack_oc_size); | |||||
#define FOR_BINARY(_op) \ | |||||
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||||
static_cast<ctype*>(conv_dst_ptr), \ | |||||
reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
N* OC* OH* OW* pack_oc_size); | N* OC* OH* OW* pack_oc_size); | ||||
#define FOR_BIAS(_bias_mode, OH, OW) \ | #define FOR_BIAS(_bias_mode, OH, OW) \ | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | ||||
#include "src/arm_common/conv_bias/quint8/stride2.h" | #include "src/arm_common/conv_bias/quint8/stride2.h" | ||||
#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -10,7 +10,7 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -11,7 +11,7 @@ | |||||
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/quint8/stride1.h" | #include "src/arm_common/conv_bias/quint8/stride1.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -12,7 +12,7 @@ | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -12,7 +12,7 @@ | |||||
#include "src/arm_common/conv_bias/quint8/stride2.h" | #include "src/arm_common/conv_bias/quint8/stride2.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -12,7 +12,7 @@ | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -10,7 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/elemwise/binary/algo.h" | #include "src/arm_common/elemwise/binary/algo.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -20,6 +20,7 @@ | |||||
MIDOUT_DECL(megdnn_arm_common_elemwise_binary) | MIDOUT_DECL(megdnn_arm_common_elemwise_binary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
namespace { | namespace { | ||||
@@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||||
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | ||||
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | ||||
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | ||||
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||||
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||||
DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ | DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ | ||||
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | ||||
DISPATCH_BINARY( \ | DISPATCH_BINARY( \ | ||||
@@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||||
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | ||||
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | ||||
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | ||||
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||||
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||||
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | ||||
DISPATCH_BINARY( \ | DISPATCH_BINARY( \ | ||||
FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ | FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ | ||||
@@ -13,7 +13,7 @@ | |||||
#include "src/arm_common/elemwise/binary/algo.h" | #include "src/arm_common/elemwise/binary/algo.h" | ||||
#include "src/arm_common/elemwise/ternary/algo.h" | #include "src/arm_common/elemwise/ternary/algo.h" | ||||
#include "src/arm_common/elemwise/unary/algo.h" | #include "src/arm_common/elemwise/unary/algo.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -12,7 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/fallback/elemwise/opr_impl.h" | #include "src/fallback/elemwise/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
@@ -10,7 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/elemwise/ternary/algo.h" | #include "src/arm_common/elemwise/ternary/algo.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -20,6 +20,7 @@ | |||||
MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) | MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | ||||
@@ -10,7 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/elemwise/unary/algo.h" | #include "src/arm_common/elemwise/unary/algo.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -20,6 +20,7 @@ | |||||
MIDOUT_DECL(megdnn_arm_common_elemwise_unary) | MIDOUT_DECL(megdnn_arm_common_elemwise_unary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace arm_common; | using namespace arm_common; | ||||
bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | ||||
@@ -0,0 +1,151 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/elemwise_helper/elemwise_op.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/elemwise_helper/op_binary.h" | |||||
#include "src/arm_common/elemwise_helper/op_ternary.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
namespace megdnn { | |||||
namespace elemwise { | |||||
// NOTE(review): this alias sits inside namespace megdnn::elemwise (opened on
// the two preceding lines) and names megdnn::elemwise::BcastType itself, so it
// looks redundant - confirm whether it can be dropped.
using BcastType = megdnn::elemwise::BcastType; | |||||
///////////////////////////////// ParamElemVisitor /////////////////////////// | |||||
// Helper macro that specializes two visitors for one element type:
//  * ParamElemVisitor<_ctype>    - loads a full vector from src with
//    vld1q_<_fun_suffix>;
//  * ParamElemVisitorDup<_ctype> - reads the single element at src and
//    replicates it into every lane with vdupq_n_<_fun_suffix>.
// src is reinterpreted as const _inner_ctype* before it reaches the intrinsic.
// The macro is instantiated right below for dt_quint8, dt_int16 and, only when
// __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is set, __fp16; it is #undef'd
// afterwards so the short name cb can be reused.
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitor<_ctype> { \ | |||||
_neon_type operator()(const _ctype* src) const { \ | |||||
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
}; \ | |||||
template <> \ | |||||
struct ParamElemVisitorDup<_ctype> { \ | |||||
_neon_type operator()(const _ctype* src) const { \ | |||||
return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
} | |||||
cb(dt_quint8, uint8_t, uint8x16_t, u8); | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
cb(__fp16, __fp16, float16x8_t, f16); | |||||
#endif | |||||
cb(dt_int16, int16_t, int16x8_t, s16); | |||||
#undef cb | |||||
// Visitor used for the "101x4" broadcast pattern; only the specializations
// generated by cb below are defined in this header.
template <typename ctype> | |||||
struct ParamElemVisitorBcast101x4; | |||||
// Each specialization reads ONE value of type _inner_ctype from src,
// replicates it across the vector with vld1q_dup_<rel_suffix>, and
// reinterprets the result as _neon_type via
// vreinterpretq_<_fun_suffix>_<rel_suffix>. In the instantiations below
// _inner_ctype is four times as wide as the element type used for the same
// _ctype in the ParamElemVisitor instantiations of this header (uint32_t vs
// uint8_t, int64_t vs int16_t), so a group of four consecutive elements is
// repeated.
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_neon_type operator()(const _ctype* src) const { \ | |||||
return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ | |||||
reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
} \ | |||||
} | |||||
cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); | |||||
cb(dt_int16, int64_t, int16x8_t, s16, s64); | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
cb(__fp16, uint64_t, float16x8_t, f16, u64); | |||||
#endif | |||||
#undef cb | |||||
// Visitor used for the "101x8" broadcast pattern. The only specialization
// generated here is for __fp16 (guarded by
// __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) and it is a plain vld1q_f16 load into
// a float16x8_t, i.e. no lane duplication is performed.
template <typename ctype> | |||||
struct ParamElemVisitorBcast101x8; | |||||
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x8<_ctype> { \ | |||||
_neon_type operator()(const _ctype* src) const { \ | |||||
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
cb(__fp16, __fp16, float16x8_t, f16); | |||||
#endif | |||||
#undef cb | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
/**
 * Specialization of OpCallerBinaryBcast101xXVec for __fp16 with a block size
 * of 8.
 *
 * run() builds a ParamElemVisitorBcast101x8 for src0 and a plain
 * ParamElemVisitor for src1, then forwards every argument unchanged to
 * OpCallerBinaryBcast101xDVec<__fp16, 8>::run, which is declared outside this
 * header section.
 */
template <> | |||||
struct OpCallerBinaryBcast101xXVec<__fp16, 8> { | |||||
using src_ctype = __fp16; | |||||
template <typename Op> | |||||
static void run( | |||||
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||||
const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
ParamElemVisitorBcast101x8<src_ctype> vis0; | |||||
ParamElemVisitor<src_ctype> vis1; | |||||
OpCallerBinaryBcast101xDVec<src_ctype, 8>::run( | |||||
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} | |||||
}; | |||||
/**
 * Specialization of OpCallerBinaryVecBcast101xX for __fp16 with a block size
 * of 8.
 *
 * Mirror image of the Bcast101xXVec case: src0 is read with the plain
 * ParamElemVisitor and src1 with ParamElemVisitorBcast101x8; all arguments are
 * then forwarded to OpCallerBinaryVecBcast101xD<__fp16, 8>::run.
 */
template <> | |||||
struct OpCallerBinaryVecBcast101xX<__fp16, 8> { | |||||
using src_ctype = __fp16; | |||||
template <typename Op> | |||||
static void run( | |||||
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||||
const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
size_t channel_stride) { | |||||
ParamElemVisitor<src_ctype> vis0; | |||||
ParamElemVisitorBcast101x8<src_ctype> vis1; | |||||
OpCallerBinaryVecBcast101xD<src_ctype, 8>::run( | |||||
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} | |||||
}; | |||||
/**
 * Specialization of OpCallerTernaryBcast101xXVecBcast101xX for __fp16 with a
 * block size of 8.
 *
 * src0 and src2 are read with ParamElemVisitorBcast101x8, src1 with the plain
 * ParamElemVisitor; the call is forwarded to
 * OpCallerTernaryBcast101xDVecBcast101xD<__fp16, 8>::run with the three
 * visitors inserted after op.
 */
template <> | |||||
struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { | |||||
using src_ctype = __fp16; | |||||
template <typename Op> | |||||
static void run( | |||||
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||||
typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||||
size_t nr_channel_blocks, size_t channel_stride) { | |||||
ParamElemVisitorBcast101x8<src_ctype> vis0; | |||||
ParamElemVisitor<src_ctype> vis1; | |||||
ParamElemVisitorBcast101x8<src_ctype> vis2; | |||||
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run( | |||||
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} | |||||
}; | |||||
/**
 * Specialization of OpCallerTernaryVecBcast101xXVec for __fp16 with a block
 * size of 8.
 *
 * src0 and src2 are read with the plain ParamElemVisitor, src1 with
 * ParamElemVisitorBcast101x8; the call is forwarded to
 * OpCallerTernaryVecBcast101xDVec<__fp16, 8>::run with the three visitors
 * inserted after op.
 */
template <> | |||||
struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { | |||||
using src_ctype = __fp16; | |||||
template <typename Op> | |||||
static void run( | |||||
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||||
typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||||
size_t nr_channel_blocks, size_t channel_stride) { | |||||
ParamElemVisitor<src_ctype> vis0; | |||||
ParamElemVisitorBcast101x8<src_ctype> vis1; | |||||
ParamElemVisitor<src_ctype> vis2; | |||||
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run( | |||||
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||||
channel_stride); | |||||
} | |||||
}; | |||||
#endif | |||||
} // namespace elemwise | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -1,36 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/elemwise_helper/kimpl/op_base.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
// when __fp16 is available POW is very slow, so add there | |||||
/////////////////////// POW float only //////////////////////////// | |||||
/**
 * Scalar power operator: dst = powf(src0, src1).
 *
 * SIMD_WIDTH is 1 and only element-wise overloads are provided, so no vector
 * overload exists in this struct. The pointer overload simply stores the
 * result of the value-returning overload into *dst.
 *
 * NOTE(review): the only include visible for this header is kimpl/op_base.h;
 * powf presumably becomes visible through it - confirm.
 */
template <typename src_ctype, typename dst_ctype = src_ctype> | |||||
struct PowOp : BinaryOpBase<src_ctype, dst_ctype> { | |||||
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||||
constexpr static size_t SIMD_WIDTH = 1; | |||||
void operator()( | |||||
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||||
*dst = operator()(src0, src1); | |||||
} | |||||
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||||
return powf(src0, src1); | |||||
} | |||||
}; | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -18,7 +18,6 @@ | |||||
#include "src/arm_common/elemwise_helper/kimpl/max.h" | #include "src/arm_common/elemwise_helper/kimpl/max.h" | ||||
#include "src/arm_common/elemwise_helper/kimpl/min.h" | #include "src/arm_common/elemwise_helper/kimpl/min.h" | ||||
#include "src/arm_common/elemwise_helper/kimpl/mul.h" | #include "src/arm_common/elemwise_helper/kimpl/mul.h" | ||||
#include "src/arm_common/elemwise_helper/kimpl/pow.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/rmulh.h" | #include "src/arm_common/elemwise_helper/kimpl/rmulh.h" | ||||
#include "src/arm_common/elemwise_helper/kimpl/sub.h" | #include "src/arm_common/elemwise_helper/kimpl/sub.h" | ||||
#include "src/arm_common/elemwise_helper/kimpl/true_div.h" | #include "src/arm_common/elemwise_helper/kimpl/true_div.h" | ||||
@@ -15,7 +15,7 @@ | |||||
#include "src/common/elemwise_multi_type/kern_defs.cuh" | #include "src/common/elemwise_multi_type/kern_defs.cuh" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
namespace { | namespace { | ||||
@@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k( | |||||
} // namespace | } // namespace | ||||
using namespace elemwise; | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
@@ -2,7 +2,7 @@ | |||||
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | ||||
*/ | */ | ||||
#include "src/fallback/elemwise/gi_impl/binary/algo.h" | #include "src/fallback/elemwise/gi_impl/binary/algo.h" | ||||
#include "src/fallback/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -12,6 +12,7 @@ | |||||
MIDOUT_DECL(megdnn_fallback_elemwise_binary) | MIDOUT_DECL(megdnn_fallback_elemwise_binary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace fallback; | using namespace fallback; | ||||
namespace { | namespace { | ||||
@@ -3,7 +3,7 @@ | |||||
*/ | */ | ||||
#include "src/fallback/elemwise/gi_impl/ternary/algo.h" | #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | ||||
#include "src/fallback/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -13,6 +13,7 @@ | |||||
MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace fallback; | using namespace fallback; | ||||
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | ||||
@@ -2,7 +2,7 @@ | |||||
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | ||||
*/ | */ | ||||
#include "src/fallback/elemwise/gi_impl/unary/algo.h" | #include "src/fallback/elemwise/gi_impl/unary/algo.h" | ||||
#include "src/fallback/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
@@ -12,6 +12,7 @@ | |||||
MIDOUT_DECL(megdnn_fallback_elemwise_unary) | MIDOUT_DECL(megdnn_fallback_elemwise_unary) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace fallback; | using namespace fallback; | ||||
bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | ||||
@@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) | |||||
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace elemwise; | |||||
using namespace fallback; | using namespace fallback; | ||||
void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | ||||
@@ -9,7 +9,7 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/fallback/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "src/naive/elemwise/opr_impl.h" | #include "src/naive/elemwise/opr_impl.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
@@ -60,7 +60,7 @@ private: | |||||
public: | public: | ||||
class AlgoBase; | class AlgoBase; | ||||
struct KernParam { | struct KernParam { | ||||
BcastType broad_cast_type; | |||||
elemwise::BcastType broad_cast_type; | |||||
Mode mode; | Mode mode; | ||||
const TensorND* m_dst; | const TensorND* m_dst; | ||||
Handle* handle; | Handle* handle; | ||||
@@ -0,0 +1,72 @@ | |||||
/** | |||||
* \file dnn/src/fallback/elemwise_helper/elemwise_op.h | |||||
*/ | |||||
#pragma once | |||||
#include "src/fallback/elemwise_helper/op_binary.h" | |||||
#include "src/fallback/elemwise_helper/op_common.h" | |||||
#include "src/fallback/elemwise_helper/op_ternary.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
#include "src/fallback/general_intrinsic/gi_int.h" | |||||
namespace megdnn { | |||||
namespace elemwise { | |||||
///////////////////////////////// ParamElemVisitor /////////////////////////// | |||||
// Helper macro that specializes two visitors for one element type using the
// general-intrinsic (Gi*) wrappers:
//  * ParamElemVisitor<_ctype>    - returns GiLoad<_fun_suffix>(src);
//  * ParamElemVisitorDup<_ctype> - reads the single element at src (as
//    _inner_ctype) and returns GiBroadcast<_fun_suffix> of it.
// Instantiated below for dt_qint32, dt_qint8, dt_float32, dt_int32 and
// dt_int8, then #undef'd so the short name cb can be reused.
// NOTE(review): unlike the Dup visitor, the load passes src without a
// reinterpret_cast; this relies on GiLoad* accepting a const _ctype* -
// confirm against the Gi headers.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitor<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
}; \ | |||||
template <> \ | |||||
struct ParamElemVisitorDup<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiBroadcast##_fun_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||||
#undef cb | |||||
// Visitor used for the "101x4" broadcast pattern; specializations are
// generated by the cb macros that follow.
template <typename ctype> | |||||
struct ParamElemVisitorBcast101x4; | |||||
// 8-bit element types: read one _inner_ctype value (int32_t in both
// instantiations, i.e. four consecutive int8 elements) from src, broadcast it
// with GiBroadcast<rel_suffix>, and reinterpret the vector with
// GiReinter<rel_suffix>To<_fun_suffix>.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||||
*reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||||
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||||
#undef cb | |||||
// ParamElemVisitorBcast101x4 for the 32-bit element types (dt_qint32,
// dt_float32, dt_int32): the visitor is a plain GiLoad<_fun_suffix>(src) with
// no broadcast step. The _inner_ctype macro parameter is not used by this
// variant.
// NOTE(review): presumably a plain load is enough because four 32-bit
// elements fill one vector - confirm against GI_SIMD_LEN_BYTE.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
template <> \ | |||||
struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
_simd_type operator()(const _ctype* src) const { \ | |||||
return GiLoad##_fun_suffix(src); \ | |||||
} \ | |||||
} | |||||
cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
#undef cb | |||||
} // namespace elemwise | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -58,7 +58,7 @@ struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | ||||
using AbsOpBase::AbsOpBase; | using AbsOpBase::AbsOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using AbsOpBase::operator(); | using AbsOpBase::operator(); | ||||
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | ||||
OPERATOR_UNARY_QINT8_FALLBACK; | OPERATOR_UNARY_QINT8_FALLBACK; | ||||
@@ -87,7 +87,7 @@ template <> | |||||
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | ||||
using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | ||||
using FuseAddHSwishOpBase::operator(); | using FuseAddHSwishOpBase::operator(); | ||||
// NOTE(review): the replacement row divides by sizeof(int8_t), whereas the
// previous value was 4, this specialization's operands are GI_INT32_V2_t, and
// the sibling HSwishOp<dt_qint32, dt_qint8> in the same diff divides by
// sizeof(int32_t). With a 16-byte GI_SIMD_LEN_BYTE this would change the
// width from 4 to 16 - looks unintended, confirm whether sizeof(int32_t) was
// meant.
constexpr static size_t SIMD_WIDTH = 4; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
void operator()( | void operator()( | ||||
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | ||||
dt_qint8* dst) const { | dt_qint8* dst) const { | ||||
@@ -83,7 +83,7 @@ template <> | |||||
struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | ||||
using HSwishOpBase::HSwishOpBase; | using HSwishOpBase::HSwishOpBase; | ||||
using HSwishOpBase::operator(); | using HSwishOpBase::operator(); | ||||
constexpr static size_t SIMD_WIDTH = 4; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||||
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | ||||
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | ||||
@@ -77,7 +77,7 @@ struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | ||||
using MaxOpBase::MaxOpBase; | using MaxOpBase::MaxOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using MaxOpBase::operator(); | using MaxOpBase::operator(); | ||||
void operator()( | void operator()( | ||||
@@ -74,7 +74,7 @@ struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | ||||
using MinOpBase::MinOpBase; | using MinOpBase::MinOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using MinOpBase::operator(); | using MinOpBase::operator(); | ||||
void operator()( | void operator()( | ||||
@@ -73,7 +73,7 @@ struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | ||||
using MulOpBase::MulOpBase; | using MulOpBase::MulOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using MulOpBase::operator(); | using MulOpBase::operator(); | ||||
void operator()( | void operator()( | ||||
@@ -54,8 +54,6 @@ struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||||
} | } | ||||
}; | }; | ||||
#pragma GCC diagnostic ignored "-Waddress-of-packed-member" | |||||
template <> | template <> | ||||
struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | ||||
using NoneOpBase::NoneOpBase; | using NoneOpBase::NoneOpBase; | ||||
@@ -63,11 +61,11 @@ struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | ||||
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | ||||
GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]); | |||||
GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]); | |||||
GiStoreInt32(dst, vsrc.val[0]); | |||||
GiStoreInt32(dst + 16, vsrc.val[1]); | |||||
} | } | ||||
void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | ||||
GiStoreInt32(reinterpret_cast<int32_t*>(dst), src); | |||||
GiStoreInt32(dst, src); | |||||
} | } | ||||
}; | }; | ||||
@@ -112,36 +112,38 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase | |||||
: ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | ||||
void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | ||||
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||||
vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc))); | |||||
} | } | ||||
int8x8_t operator()(const int32x4x2_t& vsrc) const { | |||||
int8x16_t operator()(const int32x4x2_t& vsrc) const { | |||||
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | ||||
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | ||||
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | ||||
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | ||||
return vqmovn_s16(vcombine_s16( | |||||
auto tmp = vqmovn_s16(vcombine_s16( | |||||
vqmovn_s32(vrshlq_s32(vitem0, vshift)), | vqmovn_s32(vrshlq_s32(vitem0, vshift)), | ||||
vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | ||||
return vcombine_s8(tmp, tmp); | |||||
} | } | ||||
int8x8_t operator()(const float32x4_t& vsrc) const { | |||||
int8x16_t operator()(const float32x4_t& vsrc) const { | |||||
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | ||||
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | ||||
vitem0 = vrshlq_s32(vitem0, vshift); | vitem0 = vrshlq_s32(vitem0, vshift); | ||||
int16x4_t vitem = vqmovn_s32(vitem0); | int16x4_t vitem = vqmovn_s32(vitem0); | ||||
return vqmovn_s16(vcombine_s16(vitem, vitem)); | |||||
auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem)); | |||||
return vcombine_s8(tmp, tmp); | |||||
} | } | ||||
void operator()(const int32x4_t& src, dt_qint8* dst) const { | void operator()(const int32x4_t& src, dt_qint8* dst) const { | ||||
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | ||||
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | ||||
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||||
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||||
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||||
} | } | ||||
void operator()(const float32x4_t& src, dt_qint8* dst) const { | void operator()(const float32x4_t& src, dt_qint8* dst) const { | ||||
auto vitem0 = vmulq_f32(src, this->vscale); | auto vitem0 = vmulq_f32(src, this->vscale); | ||||
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | ||||
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||||
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||||
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||||
} | } | ||||
}; | }; | ||||
@@ -73,7 +73,7 @@ struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
template <> | template <> | ||||
struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | ||||
using SubOpBase::SubOpBase; | using SubOpBase::SubOpBase; | ||||
constexpr static size_t SIMD_WIDTH = 16; | |||||
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
using SubOpBase::operator(); | using SubOpBase::operator(); | ||||
void operator()( | void operator()( | ||||
@@ -13,6 +13,7 @@ | |||||
#include "math.h" | #include "math.h" | ||||
#include "stdint.h" | #include "stdint.h" | ||||
#include "string.h" | |||||
#if defined(_WIN32) | #if defined(_WIN32) | ||||
#include <intrin.h> | #include <intrin.h> | ||||
@@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); | |||||
#define Max(a, b) (a) > (b) ? (a) : (b) | #define Max(a, b) (a) > (b) ? (a) : (b) | ||||
#define Min(a, b) (a) < (b) ? (a) : (b) | #define Min(a, b) (a) < (b) ? (a) : (b) | ||||
#if defined(GI_NEON_INTRINSICS) | |||||
#if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS) | |||||
#define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a)) | |||||
#define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a)) | |||||
#define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane)) | |||||
#else | |||||
#define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a)) | |||||
#define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a)) | |||||
#define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane)) | |||||
#endif | |||||
#endif | |||||
typedef struct { | typedef struct { | ||||
GI_INT32_t val[2]; | GI_INT32_t val[2]; | ||||
} GI_INT32_V2_t; | } GI_INT32_V2_t; | ||||
@@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) { | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_castps_si128(In); | return _mm_castps_si128(In); | ||||
#else | #else | ||||
return *(GI_INT32_t*)(&In); | |||||
GI_INT32_t ret; | |||||
memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||||
return ret; | |||||
#endif | #endif | ||||
} | } | ||||
@@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) { | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_castps_si128(In); | return _mm_castps_si128(In); | ||||
#else | #else | ||||
return *(GI_UINT32_t*)(&In); | |||||
GI_UINT32_t ret; | |||||
memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||||
return ret; | |||||
#endif | #endif | ||||
} | } | ||||
@@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) { | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_castsi128_ps(Vector); | return _mm_castsi128_ps(Vector); | ||||
#else | #else | ||||
return *(GI_FLOAT32_t*)(&Vector); | |||||
GI_FLOAT32_t ret; | |||||
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||||
return ret; | |||||
#endif | #endif | ||||
} | } | ||||
@@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) { | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_castsi128_ps(Vector); | return _mm_castsi128_ps(Vector); | ||||
#else | #else | ||||
return *(GI_FLOAT32_t*)(&Vector); | |||||
GI_FLOAT32_t ret; | |||||
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||||
return ret; | |||||
#endif | #endif | ||||
} | } | ||||
@@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) { | |||||
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); | float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); | ||||
return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); | return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); | ||||
#endif | #endif | ||||
#elif defined(GI_SSE2_INTRINSICS) | |||||
#elif defined(GI_SSE42_INTRINSICS) | |||||
__m128 vfzero = _mm_set1_ps(0.f); | __m128 vfzero = _mm_set1_ps(0.f); | ||||
__m128 vfhalf = _mm_set1_ps(0.5f); | __m128 vfhalf = _mm_set1_ps(0.5f); | ||||
__m128 vfneg_half = _mm_set1_ps(-0.5f); | __m128 vfneg_half = _mm_set1_ps(-0.5f); | ||||
@@ -322,11 +330,7 @@ GI_FORCEINLINE | |||||
GI_FLOAT32_t GiMultiplyAddFloat32( | GI_FLOAT32_t GiMultiplyAddFloat32( | ||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
#if defined(__ARM_FEATURE_FMA) | |||||
return vfmaq_f32(VectorSum, Vector1, Vector2); | |||||
#else | |||||
return vmlaq_f32(VectorSum, Vector1, Vector2); | |||||
#endif | |||||
return v_fma_ps_f32(VectorSum, Vector1, Vector2); | |||||
#elif defined(GI_FMA3_INTRINSICS) | #elif defined(GI_FMA3_INTRINSICS) | ||||
return _mm_fmadd_ps(Vector1, Vector2, VectorSum); | return _mm_fmadd_ps(Vector1, Vector2, VectorSum); | ||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
@@ -352,11 +356,7 @@ GI_FORCEINLINE | |||||
GI_FLOAT32_t GiMultiplyAddScalarFloat32( | GI_FLOAT32_t GiMultiplyAddScalarFloat32( | ||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
#if defined(__ARM_FEATURE_FMA) | |||||
return vfmaq_n_f32(VectorSum, Vector, Scalar); | |||||
#else | |||||
return vfmla_n_f32(VectorSum, Vector, Scalar); | |||||
#endif | |||||
return v_fma_n_f32(VectorSum, Vector, Scalar); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); | return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); | ||||
#else | #else | ||||
@@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||||
} | } | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
#if defined(__ARM_FEATURE_FMA) | |||||
#define GIMULTIPLYADDLANFLOAT32(i) \ | |||||
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||||
return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
} | |||||
GIMULTIPLYADDLANFLOAT32(0) | |||||
GIMULTIPLYADDLANFLOAT32(1) | |||||
#undef GIMULTIPLYADDLANFLOAT32 | |||||
#define GIMULTIPLYADDLANFLOAT32(i) \ | |||||
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||||
return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
} | |||||
GIMULTIPLYADDLANFLOAT32(2) | |||||
GIMULTIPLYADDLANFLOAT32(3) | |||||
#else | |||||
#define GIMULTIPLYADDLANFLOAT32(i) \ | #define GIMULTIPLYADDLANFLOAT32(i) \ | ||||
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | ||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | ||||
return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
} | } | ||||
GIMULTIPLYADDLANFLOAT32(0) | GIMULTIPLYADDLANFLOAT32(0) | ||||
GIMULTIPLYADDLANFLOAT32(1) | GIMULTIPLYADDLANFLOAT32(1) | ||||
@@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1) | |||||
#define GIMULTIPLYADDLANFLOAT32(i) \ | #define GIMULTIPLYADDLANFLOAT32(i) \ | ||||
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | ||||
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | ||||
return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
} | } | ||||
GIMULTIPLYADDLANFLOAT32(2) | GIMULTIPLYADDLANFLOAT32(2) | ||||
GIMULTIPLYADDLANFLOAT32(3) | GIMULTIPLYADDLANFLOAT32(3) | ||||
#endif | |||||
#undef GIMULTIPLYADDLANFLOAT32 | #undef GIMULTIPLYADDLANFLOAT32 | ||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
@@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) { | |||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
GI_INT32_t GiLoadInt32(const int32_t* Buffer) { | |||||
GI_INT32_t GiLoadInt32(const void* Buffer) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
return vld1q_s32(Buffer); | |||||
return vld1q_s32((int32_t*)Buffer); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_loadu_si128((const __m128i*)Buffer); | return _mm_loadu_si128((const __m128i*)Buffer); | ||||
#else | #else | ||||
GI_INT32_t ret; | GI_INT32_t ret; | ||||
const int32_t* ptr = (int32_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | ||||
ret[i] = Buffer[i]; | |||||
ret[i] = ptr[i]; | |||||
} | } | ||||
return ret; | return ret; | ||||
#endif | #endif | ||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
GI_INT8_t GiLoadInt8(const int8_t* Buffer) { | |||||
GI_INT8_t GiLoadInt8(const void* Buffer) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
return vld1q_s8(Buffer); | |||||
return vld1q_s8((int8_t*)Buffer); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
return _mm_loadu_si128((const __m128i*)Buffer); | return _mm_loadu_si128((const __m128i*)Buffer); | ||||
#else | #else | ||||
GI_INT8_t ret; | GI_INT8_t ret; | ||||
const int8_t* ptr = (int8_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | ||||
ret[i] = Buffer[i]; | |||||
ret[i] = ptr[i]; | |||||
} | } | ||||
return ret; | return ret; | ||||
#endif | #endif | ||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) { | |||||
void GiStoreInt32(void* Buffer, GI_INT32_t Vector) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
vst1q_s32(Buffer, Vector); | |||||
vst1q_s32((int32_t*)Buffer, Vector); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
_mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
#else | #else | ||||
int32_t* ptr = (int32_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | ||||
Buffer[i] = Vector[i]; | |||||
ptr[i] = Vector[i]; | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
#define GISTORELANEINT32(i) \ | |||||
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
vst1q_lane_s32(Buffer, Vector, i); \ | |||||
#define GISTORELANEINT32(i) \ | |||||
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
vst1q_lane_s32((int32_t*)Buffer, Vector, i); \ | |||||
} | } | ||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
#define GISTORELANEINT32(i) \ | #define GISTORELANEINT32(i) \ | ||||
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ | GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ | ||||
_mm_store_ss( \ | _mm_store_ss( \ | ||||
(float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ | (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ | ||||
} | } | ||||
#else | #else | ||||
#define GISTORELANEINT32(i) \ | |||||
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
*Buffer = Vector[i]; \ | |||||
#define GISTORELANEINT32(i) \ | |||||
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
*((int32_t*)Buffer) = Vector[i]; \ | |||||
} | } | ||||
#endif | #endif | ||||
@@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) { | |||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) { | |||||
void GiStoreInt16(void* Buffer, GI_INT16_t Vector) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
vst1q_s16(Buffer, Vector); | |||||
vst1q_s16((int16_t*)Buffer, Vector); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
_mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
#else | #else | ||||
int16_t* ptr = (int16_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { | ||||
Buffer[i] = Vector[i]; | |||||
ptr[i] = Vector[i]; | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
void GiStoreInt8(void* Buffer, GI_INT8_t Vector) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
vst1q_s8(Buffer, Vector); | |||||
vst1q_s8((int8_t*)Buffer, Vector); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
_mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
#else | #else | ||||
int8_t* ptr = (int8_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | ||||
Buffer[i] = Vector[i]; | |||||
ptr[i] = Vector[i]; | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
vst1_s8(Buffer, vget_low_s8(Vector)); | |||||
vst1_s8((int8_t*)Buffer, vget_low_s8(Vector)); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
_mm_storel_epi64((__m128i*)Buffer, Vector); | _mm_storel_epi64((__m128i*)Buffer, Vector); | ||||
#else | #else | ||||
int8_t* ptr = (int8_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | ||||
Buffer[i] = Vector[i]; | |||||
ptr[i] = Vector[i]; | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
GI_FORCEINLINE | GI_FORCEINLINE | ||||
void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) { | |||||
#if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
vst1_s8(Buffer, vget_high_s8(Vector)); | |||||
vst1_s8((int8_t*)Buffer, vget_high_s8(Vector)); | |||||
#elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
_mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); | _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); | ||||
#else | #else | ||||
int8_t* ptr = (int8_t*)Buffer; | |||||
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | ||||
Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||||
ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
@@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) { | |||||
checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | ||||
} | } | ||||
TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | ||||
using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||