Browse Source

fix(fallback): delete the repeat opcaller in fallback and arm_common

GitOrigin-RevId: 87046b8197
release-1.10
Megvii Engine Team 3 years ago
parent
commit
ff6a3bb819
80 changed files with 1810 additions and 3211 deletions
  1. +1
    -1
      dnn/src/aarch64/conv_bias/int8/algos.cpp
  2. +1
    -1
      dnn/src/aarch64/conv_bias/quint8/algos.cpp
  3. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
  4. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp
  5. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp
  6. +1
    -1
      dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp
  7. +1
    -1
      dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp
  8. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
  9. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
  10. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp
  11. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
  12. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
  13. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
  14. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
  15. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
  16. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
  17. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
  18. +1
    -1
      dnn/src/arm_common/conv_bias/int8/algos.cpp
  19. +1
    -1
      dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp
  20. +1
    -1
      dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp
  21. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct.cpp
  22. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
  23. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
  24. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
  25. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
  26. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
  27. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
  28. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
  29. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
  30. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
  31. +1
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
  32. +1
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
  33. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride1.cpp
  34. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
  35. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride2.cpp
  36. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
  37. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
  38. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
  39. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h
  40. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h
  41. +1
    -1
      dnn/src/arm_common/conv_bias/matmul_postprocess.h
  42. +61
    -59
      dnn/src/arm_common/conv_bias/postprocess_helper.h
  43. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/algos.cpp
  44. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/direct.cpp
  45. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
  46. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride1.cpp
  47. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
  48. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride2.cpp
  49. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
  50. +4
    -3
      dnn/src/arm_common/elemwise/binary/algo.cpp
  51. +1
    -1
      dnn/src/arm_common/elemwise/opr_impl.cpp
  52. +1
    -1
      dnn/src/arm_common/elemwise/opr_impl.h
  53. +2
    -1
      dnn/src/arm_common/elemwise/ternary/algo.cpp
  54. +2
    -1
      dnn/src/arm_common/elemwise/unary/algo.cpp
  55. +151
    -0
      dnn/src/arm_common/elemwise_helper/elemwise_op.h
  56. +0
    -36
      dnn/src/arm_common/elemwise_helper/kimpl/pow.h
  57. +0
    -1
      dnn/src/arm_common/elemwise_helper/op_binary.h
  58. +3
    -1
      dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
  59. +0
    -1537
      dnn/src/arm_common/elemwise_op.h
  60. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
  61. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp
  62. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
  63. +1
    -0
      dnn/src/fallback/elemwise/opr_impl.cpp
  64. +2
    -2
      dnn/src/fallback/elemwise/opr_impl.h
  65. +72
    -0
      dnn/src/fallback/elemwise_helper/elemwise_op.h
  66. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/abs.h
  67. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
  68. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/hswish.h
  69. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/max.h
  70. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/min.h
  71. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/mul.h
  72. +3
    -5
      dnn/src/fallback/elemwise_helper/kimpl/none.h
  73. +11
    -9
      dnn/src/fallback/elemwise_helper/kimpl/relu.h
  74. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/sub.h
  75. +1370
    -0
      dnn/src/fallback/elemwise_helper/op_common.h
  76. +0
    -1432
      dnn/src/fallback/elemwise_op.h
  77. +13
    -0
      dnn/src/fallback/general_intrinsic/gi_common.h
  78. +17
    -35
      dnn/src/fallback/general_intrinsic/gi_float.h
  79. +35
    -28
      dnn/src/fallback/general_intrinsic/gi_int.h
  80. +0
    -1
      dnn/test/fallback/elemwise.cpp

+ 1
- 1
dnn/src/aarch64/conv_bias/int8/algos.cpp View File

@@ -12,7 +12,7 @@
#include "src/aarch64/conv_bias/int8/algos.h" #include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/int8/strategy.h" #include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h" #include "src/fallback/matrix_mul/gemm_impl.h"


+ 1
- 1
dnn/src/aarch64/conv_bias/quint8/algos.cpp View File

@@ -14,7 +14,7 @@
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h" #include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h" #include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h" #include "src/fallback/matrix_mul/gemm_impl.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp View File

@@ -11,7 +11,7 @@
*/ */


#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h" #include "src/arm_common/utils.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp View File

@@ -12,7 +12,7 @@


#include "src/arm_common/conv_bias/f16/algos.h" #include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "midout.h" #include "midout.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp View File

@@ -12,7 +12,7 @@


#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h"
#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h" #include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/f16/algos.h" #include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"


#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "midout.h" #include "midout.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp View File

@@ -13,7 +13,7 @@


#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp View File

@@ -11,7 +11,7 @@
*/ */


#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h" #include "src/arm_common/utils.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp View File

@@ -11,7 +11,7 @@
*/ */


#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h" #include "src/arm_common/utils.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp View File

@@ -12,7 +12,7 @@


#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "midout.h" #include "midout.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h" #include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"


#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "midout.h" #include "midout.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h" #include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h" #include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/algos.cpp View File

@@ -17,7 +17,7 @@
#include "src/arm_common/conv_bias/int8/stride1_dotprod.h" #include "src/arm_common/conv_bias/int8/stride1_dotprod.h"
#include "src/arm_common/conv_bias/int8/stride2.h" #include "src/arm_common/conv_bias/int8/stride2.h"
#include "src/arm_common/conv_bias/int8/stride2_dotprod.h" #include "src/arm_common/conv_bias/int8/stride2_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


#include "midout.h" #include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp View File

@@ -11,7 +11,7 @@
*/ */


#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


#include "midout.h" #include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct.cpp View File

@@ -10,7 +10,7 @@
*/ */


#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp View File

@@ -11,7 +11,7 @@


#include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "midout.h" #include "midout.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h" #include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h" #include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


#include "midout.h" #include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h View File

@@ -12,7 +12,7 @@
#pragma once #pragma once
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h" #include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h" #include "src/common/nchw_nchwxx_valid.h"


#include "midout.h" #include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h View File

@@ -14,7 +14,7 @@


#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride1.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride2.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct.h" #include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp View File

@@ -11,7 +11,7 @@
*/ */


#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h" #include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"




+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h View File

@@ -16,7 +16,7 @@
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/matmul_postprocess.h View File

@@ -12,7 +12,7 @@


#include "megdnn/dtype.h" #include "megdnn/dtype.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 61
- 59
dnn/src/arm_common/conv_bias/postprocess_helper.h View File

@@ -13,8 +13,8 @@
#pragma once #pragma once


#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/elemwise_helper/kimpl/op_base.h" #include "src/arm_common/elemwise_helper/kimpl/op_base.h"
#include "src/arm_common/elemwise_op.h"
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"


#include "midout.h" #include "midout.h"
@@ -44,29 +44,29 @@ namespace {
break; break;


#define FOR_NONLINEAR_UNARY(_op) \ #define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \
megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size); bias_type, dst_type, N* OC* OH* OW* pack_oc_size);


#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW); OH* OW);


#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size); N* OC* OH* OW* pack_oc_size);


#define FOR_BIAS(_mode) \ #define FOR_BIAS(_mode) \
@@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_BIAS #undef FOR_BIAS
#undef HANDLE_IDENTITY #undef HANDLE_IDENTITY


#define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \
static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
#define FOR_NONLINEAR_UNARY(_op) \
megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::elemwise::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW); N, OC, OH* OW);


#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::elemwise::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::elemwise::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);


#define HANDLE_IDENTITY(_caller, _op) \ #define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \ case megdnn::NonlineMode::IDENTITY: \
@@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR #undef FOR_NONLINEAR
#undef FOR_BIAS #undef FOR_BIAS


#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
#define FOR_BINARY_BROADCAST(_op) \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW); OH* OW);


#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ #define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size); N* OC* OH* OW* pack_oc_size);


#define FOR_BIAS(_bias_mode, OH, OW) \ #define FOR_BIAS(_bias_mode, OH, OW) \


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/algos.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
#include "src/arm_common/conv_bias/quint8/stride2.h" #include "src/arm_common/conv_bias/quint8/stride2.h"
#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


#include "midout.h" #include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/direct.cpp View File

@@ -10,7 +10,7 @@
*/ */


#include "src/arm_common/conv_bias/quint8/direct.h" #include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp View File

@@ -11,7 +11,7 @@


#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride1.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/quint8/stride1.h" #include "src/arm_common/conv_bias/quint8/stride1.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct.h" #include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp View File

@@ -12,7 +12,7 @@
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride2.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/quint8/stride2.h" #include "src/arm_common/conv_bias/quint8/stride2.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct.h" #include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp View File

@@ -12,7 +12,7 @@
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;


+ 4
- 3
dnn/src/arm_common/elemwise/binary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied. * implied.
*/ */
#include "src/arm_common/elemwise/binary/algo.h" #include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_binary) MIDOUT_DECL(megdnn_arm_common_elemwise_binary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace arm_common; using namespace arm_common;


namespace { namespace {
@@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \
DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
DISPATCH_BINARY( \ DISPATCH_BINARY( \
@@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
DISPATCH_BINARY( \ DISPATCH_BINARY( \
FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \


+ 1
- 1
dnn/src/arm_common/elemwise/opr_impl.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/elemwise/binary/algo.h" #include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise/ternary/algo.h" #include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise/unary/algo.h" #include "src/arm_common/elemwise/unary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"


#include "src/common/utils.h" #include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/elemwise/opr_impl.h View File

@@ -12,7 +12,7 @@
#pragma once #pragma once
#include "src/fallback/elemwise/opr_impl.h" #include "src/fallback/elemwise/opr_impl.h"


#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {


+ 2
- 1
dnn/src/arm_common/elemwise/ternary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied. * implied.
*/ */
#include "src/arm_common/elemwise/ternary/algo.h" #include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) MIDOUT_DECL(megdnn_arm_common_elemwise_ternary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace arm_common; using namespace arm_common;


#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \


+ 2
- 1
dnn/src/arm_common/elemwise/unary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied. * implied.
*/ */
#include "src/arm_common/elemwise/unary/algo.h" #include "src/arm_common/elemwise/unary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_unary) MIDOUT_DECL(megdnn_arm_common_elemwise_unary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace arm_common; using namespace arm_common;


bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {


+ 151
- 0
dnn/src/arm_common/elemwise_helper/elemwise_op.h View File

@@ -0,0 +1,151 @@
/**
* \file dnn/src/arm_common/elemwise_helper/elemwise_op.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

namespace megdnn {
namespace elemwise {

using BcastType = megdnn::elemwise::BcastType;

///////////////////////////////// ParamElemVistor ///////////////////////////

#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb(dt_quint8, uint8_t, uint8x16_t, u8);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
#endif
cb(dt_int16, int16_t, int16x8_t, s16);
#undef cb

template <typename ctype>
struct ParamElemVisitorBcast101x4;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \
reinterpret_cast<const _inner_ctype*>(src))); \
} \
}

cb(dt_quint8, uint32_t, uint8x16_t, u8, u32);
cb(dt_int16, int64_t, int16x8_t, s16, s64);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, uint64_t, float16x8_t, f16, u64);
#endif
#undef cb

template <typename ctype>
struct ParamElemVisitorBcast101x8;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x8<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;

template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
OpCallerBinaryBcast101xDVec<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
OpCallerBinaryVecBcast101xD<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x8<src_ctype> vis2;
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
ParamElemVisitor<src_ctype> vis2;
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
}
};
#endif

} // namespace elemwise
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 36
dnn/src/arm_common/elemwise_helper/kimpl/pow.h View File

@@ -1,36 +0,0 @@
/**
* \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/arm_common/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace arm_common {

// when __fp16 is avaliable POW is very slow, so add there
/////////////////////// POW float only ////////////////////////////
template <typename src_ctype, typename dst_ctype = src_ctype>
struct PowOp : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
constexpr static size_t SIMD_WIDTH = 1;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return powf(src0, src1);
}
};

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 1
dnn/src/arm_common/elemwise_helper/op_binary.h View File

@@ -18,7 +18,6 @@
#include "src/arm_common/elemwise_helper/kimpl/max.h" #include "src/arm_common/elemwise_helper/kimpl/max.h"
#include "src/arm_common/elemwise_helper/kimpl/min.h" #include "src/arm_common/elemwise_helper/kimpl/min.h"
#include "src/arm_common/elemwise_helper/kimpl/mul.h" #include "src/arm_common/elemwise_helper/kimpl/mul.h"
#include "src/arm_common/elemwise_helper/kimpl/pow.h"
#include "src/arm_common/elemwise_helper/kimpl/rmulh.h" #include "src/arm_common/elemwise_helper/kimpl/rmulh.h"
#include "src/arm_common/elemwise_helper/kimpl/sub.h" #include "src/arm_common/elemwise_helper/kimpl/sub.h"
#include "src/arm_common/elemwise_helper/kimpl/true_div.h" #include "src/arm_common/elemwise_helper/kimpl/true_div.h"


+ 3
- 1
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp View File

@@ -15,7 +15,7 @@
#include "src/common/elemwise_multi_type/kern_defs.cuh" #include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h" #include "src/naive/handle.h"


#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"


namespace { namespace {
@@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k(


} // namespace } // namespace


using namespace elemwise;

namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {




+ 0
- 1537
dnn/src/arm_common/elemwise_op.h
File diff suppressed because it is too large
View File


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp View File

@@ -2,7 +2,7 @@
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
*/ */
#include "src/fallback/elemwise/gi_impl/binary/algo.h" #include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -12,6 +12,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_binary) MIDOUT_DECL(megdnn_fallback_elemwise_binary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace fallback; using namespace fallback;


namespace { namespace {


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp View File

@@ -3,7 +3,7 @@
*/ */


#include "src/fallback/elemwise/gi_impl/ternary/algo.h" #include "src/fallback/elemwise/gi_impl/ternary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -13,6 +13,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_ternary) MIDOUT_DECL(megdnn_fallback_elemwise_ternary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace fallback; using namespace fallback;


#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp View File

@@ -2,7 +2,7 @@
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
*/ */
#include "src/fallback/elemwise/gi_impl/unary/algo.h" #include "src/fallback/elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
@@ -12,6 +12,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_unary) MIDOUT_DECL(megdnn_fallback_elemwise_unary)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace fallback; using namespace fallback;


bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {


+ 1
- 0
dnn/src/fallback/elemwise/opr_impl.cpp View File

@@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)


using namespace megdnn; using namespace megdnn;
using namespace elemwise;
using namespace fallback; using namespace fallback;


void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {


+ 2
- 2
dnn/src/fallback/elemwise/opr_impl.h View File

@@ -9,7 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once #pragma once
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"
#include "src/naive/elemwise/opr_impl.h" #include "src/naive/elemwise/opr_impl.h"


namespace megdnn { namespace megdnn {
@@ -60,7 +60,7 @@ private:
public: public:
class AlgoBase; class AlgoBase;
struct KernParam { struct KernParam {
BcastType broad_cast_type;
elemwise::BcastType broad_cast_type;
Mode mode; Mode mode;
const TensorND* m_dst; const TensorND* m_dst;
Handle* handle; Handle* handle;


+ 72
- 0
dnn/src/fallback/elemwise_helper/elemwise_op.h View File

@@ -0,0 +1,72 @@
/**
* \file dnn/src/fallback/elemwise_helper/elemwise_op.h
*/

#pragma once

#include "src/fallback/elemwise_helper/op_binary.h"
#include "src/fallback/elemwise_helper/op_common.h"
#include "src/fallback/elemwise_helper/op_ternary.h"
#include "src/fallback/elemwise_helper/op_unary.h"

#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace elemwise {

///////////////////////////////// ParamElemVistor ///////////////////////////

#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiBroadcast##_fun_suffix( \
*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);

cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb

template <typename ctype>
struct ParamElemVisitorBcast101x4;
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
*reinterpret_cast<const _inner_ctype*>(src))); \
} \
}

cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}

cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb

} // namespace elemwise
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/abs.h View File

@@ -58,7 +58,7 @@ struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
template <> template <>
struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> {
using AbsOpBase::AbsOpBase; using AbsOpBase::AbsOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using AbsOpBase::operator(); using AbsOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK; OPERATOR_UNARY_QINT8_FALLBACK;


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h View File

@@ -87,7 +87,7 @@ template <>
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
using FuseAddHSwishOpBase::FuseAddHSwishOpBase; using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
using FuseAddHSwishOpBase::operator(); using FuseAddHSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
void operator()( void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const { dt_qint8* dst) const {


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/hswish.h View File

@@ -83,7 +83,7 @@ template <>
struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
using HSwishOpBase::HSwishOpBase; using HSwishOpBase::HSwishOpBase;
using HSwishOpBase::operator(); using HSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);


void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/max.h View File

@@ -77,7 +77,7 @@ struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <> template <>
struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> {
using MaxOpBase::MaxOpBase; using MaxOpBase::MaxOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MaxOpBase::operator(); using MaxOpBase::operator();


void operator()( void operator()(


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/min.h View File

@@ -74,7 +74,7 @@ struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <> template <>
struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> {
using MinOpBase::MinOpBase; using MinOpBase::MinOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MinOpBase::operator(); using MinOpBase::operator();


void operator()( void operator()(


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/mul.h View File

@@ -73,7 +73,7 @@ struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <> template <>
struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> {
using MulOpBase::MulOpBase; using MulOpBase::MulOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MulOpBase::operator(); using MulOpBase::operator();


void operator()( void operator()(


+ 3
- 5
dnn/src/fallback/elemwise_helper/kimpl/none.h View File

@@ -54,8 +54,6 @@ struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
} }
}; };


#pragma GCC diagnostic ignored "-Waddress-of-packed-member"

template <> template <>
struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
using NoneOpBase::NoneOpBase; using NoneOpBase::NoneOpBase;
@@ -63,11 +61,11 @@ struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);


void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]);
GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]);
GiStoreInt32(dst, vsrc.val[0]);
GiStoreInt32(dst + 16, vsrc.val[1]);
} }
void operator()(const GI_INT32_t& src, dt_qint8* dst) const { void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), src);
GiStoreInt32(dst, src);
} }
}; };




+ 11
- 9
dnn/src/fallback/elemwise_helper/kimpl/relu.h View File

@@ -112,36 +112,38 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
: ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}


void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc)));
} }


int8x8_t operator()(const int32x4x2_t& vsrc) const {
int8x16_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero());
return vqmovn_s16(vcombine_s16(
auto tmp = vqmovn_s16(vcombine_s16(
vqmovn_s32(vrshlq_s32(vitem0, vshift)), vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift)))); vqmovn_s32(vrshlq_s32(vitem1, vshift))));
return vcombine_s8(tmp, tmp);
} }
int8x8_t operator()(const float32x4_t& vsrc) const {
int8x16_t operator()(const float32x4_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem0 = vrshlq_s32(vitem0, vshift); vitem0 = vrshlq_s32(vitem0, vshift);
int16x4_t vitem = vqmovn_s32(vitem0); int16x4_t vitem = vqmovn_s32(vitem0);
return vqmovn_s16(vcombine_s16(vitem, vitem));
auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem));
return vcombine_s8(tmp, tmp);
} }
void operator()(const int32x4_t& src, dt_qint8* dst) const { void operator()(const int32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
} }
void operator()(const float32x4_t& src, dt_qint8* dst) const { void operator()(const float32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(src, this->vscale); auto vitem0 = vmulq_f32(src, this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
} }
}; };




+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/sub.h View File

@@ -73,7 +73,7 @@ struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <> template <>
struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> {
using SubOpBase::SubOpBase; using SubOpBase::SubOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using SubOpBase::operator(); using SubOpBase::operator();


void operator()( void operator()(


+ 1370
- 0
dnn/src/fallback/elemwise_helper/op_common.h
File diff suppressed because it is too large
View File


+ 0
- 1432
dnn/src/fallback/elemwise_op.h
File diff suppressed because it is too large
View File


+ 13
- 0
dnn/src/fallback/general_intrinsic/gi_common.h View File

@@ -13,6 +13,7 @@


#include "math.h" #include "math.h"
#include "stdint.h" #include "stdint.h"
#include "string.h"


#if defined(_WIN32) #if defined(_WIN32)
#include <intrin.h> #include <intrin.h>
@@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16)));
#define Max(a, b) (a) > (b) ? (a) : (b) #define Max(a, b) (a) > (b) ? (a) : (b)
#define Min(a, b) (a) < (b) ? (a) : (b) #define Min(a, b) (a) < (b) ? (a) : (b)


#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS)
#define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane))
#else
#define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane))
#endif
#endif

typedef struct { typedef struct {
GI_INT32_t val[2]; GI_INT32_t val[2];
} GI_INT32_V2_t; } GI_INT32_V2_t;


+ 17
- 35
dnn/src/fallback/general_intrinsic/gi_float.h View File

@@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) {
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_castps_si128(In); return _mm_castps_si128(In);
#else #else
return *(GI_INT32_t*)(&In);
GI_INT32_t ret;
memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
return ret;
#endif #endif
} }


@@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) {
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_castps_si128(In); return _mm_castps_si128(In);
#else #else
return *(GI_UINT32_t*)(&In);
GI_UINT32_t ret;
memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
return ret;
#endif #endif
} }


@@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) {
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_castsi128_ps(Vector); return _mm_castsi128_ps(Vector);
#else #else
return *(GI_FLOAT32_t*)(&Vector);
GI_FLOAT32_t ret;
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
return ret;
#endif #endif
} }


@@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) {
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_castsi128_ps(Vector); return _mm_castsi128_ps(Vector);
#else #else
return *(GI_FLOAT32_t*)(&Vector);
GI_FLOAT32_t ret;
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
return ret;
#endif #endif
} }


@@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half);
return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); return vcvtq_s32_f32(vaddq_f32(Vector, vinc0));
#endif #endif
#elif defined(GI_SSE2_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f); __m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f); __m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f); __m128 vfneg_half = _mm_set1_ps(-0.5f);
@@ -322,11 +330,7 @@ GI_FORCEINLINE
GI_FLOAT32_t GiMultiplyAddFloat32( GI_FLOAT32_t GiMultiplyAddFloat32(
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
return vfmaq_f32(VectorSum, Vector1, Vector2);
#else
return vmlaq_f32(VectorSum, Vector1, Vector2);
#endif
return v_fma_ps_f32(VectorSum, Vector1, Vector2);
#elif defined(GI_FMA3_INTRINSICS) #elif defined(GI_FMA3_INTRINSICS)
return _mm_fmadd_ps(Vector1, Vector2, VectorSum); return _mm_fmadd_ps(Vector1, Vector2, VectorSum);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
@@ -352,11 +356,7 @@ GI_FORCEINLINE
GI_FLOAT32_t GiMultiplyAddScalarFloat32( GI_FLOAT32_t GiMultiplyAddScalarFloat32(
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
return vfmaq_n_f32(VectorSum, Vector, Scalar);
#else
return vfmla_n_f32(VectorSum, Vector, Scalar);
#endif
return v_fma_n_f32(VectorSum, Vector, Scalar);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector);
#else #else
@@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32(
} }


#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
}
GIMULTIPLYADDLANFLOAT32(0)
GIMULTIPLYADDLANFLOAT32(1)
#undef GIMULTIPLYADDLANFLOAT32
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
}
GIMULTIPLYADDLANFLOAT32(2)
GIMULTIPLYADDLANFLOAT32(3)
#else
#define GIMULTIPLYADDLANFLOAT32(i) \ #define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
} }
GIMULTIPLYADDLANFLOAT32(0) GIMULTIPLYADDLANFLOAT32(0)
GIMULTIPLYADDLANFLOAT32(1) GIMULTIPLYADDLANFLOAT32(1)
@@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1)
#define GIMULTIPLYADDLANFLOAT32(i) \ #define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
} }
GIMULTIPLYADDLANFLOAT32(2) GIMULTIPLYADDLANFLOAT32(2)
GIMULTIPLYADDLANFLOAT32(3) GIMULTIPLYADDLANFLOAT32(3)
#endif
#undef GIMULTIPLYADDLANFLOAT32 #undef GIMULTIPLYADDLANFLOAT32
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)




+ 35
- 28
dnn/src/fallback/general_intrinsic/gi_int.h View File

@@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) {
} }


GI_FORCEINLINE GI_FORCEINLINE
GI_INT32_t GiLoadInt32(const int32_t* Buffer) {
GI_INT32_t GiLoadInt32(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
return vld1q_s32(Buffer);
return vld1q_s32((int32_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_loadu_si128((const __m128i*)Buffer); return _mm_loadu_si128((const __m128i*)Buffer);
#else #else
GI_INT32_t ret; GI_INT32_t ret;
const int32_t* ptr = (int32_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
ret[i] = Buffer[i];
ret[i] = ptr[i];
} }
return ret; return ret;
#endif #endif
} }


GI_FORCEINLINE GI_FORCEINLINE
GI_INT8_t GiLoadInt8(const int8_t* Buffer) {
GI_INT8_t GiLoadInt8(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
return vld1q_s8(Buffer);
return vld1q_s8((int8_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_loadu_si128((const __m128i*)Buffer); return _mm_loadu_si128((const __m128i*)Buffer);
#else #else
GI_INT8_t ret; GI_INT8_t ret;
const int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
ret[i] = Buffer[i];
ret[i] = ptr[i];
} }
return ret; return ret;
#endif #endif
} }


GI_FORCEINLINE GI_FORCEINLINE
void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) {
void GiStoreInt32(void* Buffer, GI_INT32_t Vector) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
vst1q_s32(Buffer, Vector);
vst1q_s32((int32_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector); _mm_storeu_si128((__m128i*)Buffer, Vector);
#else #else
int32_t* ptr = (int32_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
} }
#endif #endif
} }


#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
vst1q_lane_s32(Buffer, Vector, i); \
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
vst1q_lane_s32((int32_t*)Buffer, Vector, i); \
} }


#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)


#define GISTORELANEINT32(i) \ #define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \
_mm_store_ss( \ _mm_store_ss( \
(float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \
} }
#else #else
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
*Buffer = Vector[i]; \
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
*((int32_t*)Buffer) = Vector[i]; \
} }
#endif #endif


@@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) {
} }


GI_FORCEINLINE GI_FORCEINLINE
void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) {
void GiStoreInt16(void* Buffer, GI_INT16_t Vector) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
vst1q_s16(Buffer, Vector);
vst1q_s16((int16_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector); _mm_storeu_si128((__m128i*)Buffer, Vector);
#else #else
int16_t* ptr = (int16_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
} }
#endif #endif
} }


GI_FORCEINLINE GI_FORCEINLINE
void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
vst1q_s8(Buffer, Vector);
vst1q_s8((int8_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector); _mm_storeu_si128((__m128i*)Buffer, Vector);
#else #else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
} }
#endif #endif
} }


GI_FORCEINLINE GI_FORCEINLINE
void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
vst1_s8(Buffer, vget_low_s8(Vector));
vst1_s8((int8_t*)Buffer, vget_low_s8(Vector));
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
_mm_storel_epi64((__m128i*)Buffer, Vector); _mm_storel_epi64((__m128i*)Buffer, Vector);
#else #else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
} }
#endif #endif
} }


GI_FORCEINLINE GI_FORCEINLINE
void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
vst1_s8(Buffer, vget_high_s8(Vector));
vst1_s8((int8_t*)Buffer, vget_high_s8(Vector));
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
_mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector));
#else #else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
} }
#endif #endif
} }


+ 0
- 1
dnn/test/fallback/elemwise.cpp View File

@@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) {
checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
} }



TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) {
using Mode = ElemwiseForward::Param::Mode; using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle()); Checker<ElemwiseForward> checker(handle());


Loading…
Cancel
Save