Browse Source

fix(fallback): delete the duplicated OpCaller in fallback and arm_common

GitOrigin-RevId: 87046b8197
release-1.10
Megvii Engine Team 3 years ago
parent
commit
ff6a3bb819
80 changed files with 1810 additions and 3211 deletions
  1. +1
    -1
      dnn/src/aarch64/conv_bias/int8/algos.cpp
  2. +1
    -1
      dnn/src/aarch64/conv_bias/quint8/algos.cpp
  3. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
  4. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp
  5. +1
    -1
      dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp
  6. +1
    -1
      dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp
  7. +1
    -1
      dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp
  8. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
  9. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
  10. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp
  11. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
  12. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
  13. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
  14. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
  15. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
  16. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
  17. +1
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
  18. +1
    -1
      dnn/src/arm_common/conv_bias/int8/algos.cpp
  19. +1
    -1
      dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp
  20. +1
    -1
      dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp
  21. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct.cpp
  22. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
  23. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
  24. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
  25. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
  26. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
  27. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
  28. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
  29. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
  30. +1
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
  31. +1
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
  32. +1
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
  33. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride1.cpp
  34. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
  35. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride2.cpp
  36. +1
    -1
      dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
  37. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
  38. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
  39. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h
  40. +1
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h
  41. +1
    -1
      dnn/src/arm_common/conv_bias/matmul_postprocess.h
  42. +61
    -59
      dnn/src/arm_common/conv_bias/postprocess_helper.h
  43. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/algos.cpp
  44. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/direct.cpp
  45. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
  46. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride1.cpp
  47. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
  48. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride2.cpp
  49. +1
    -1
      dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
  50. +4
    -3
      dnn/src/arm_common/elemwise/binary/algo.cpp
  51. +1
    -1
      dnn/src/arm_common/elemwise/opr_impl.cpp
  52. +1
    -1
      dnn/src/arm_common/elemwise/opr_impl.h
  53. +2
    -1
      dnn/src/arm_common/elemwise/ternary/algo.cpp
  54. +2
    -1
      dnn/src/arm_common/elemwise/unary/algo.cpp
  55. +151
    -0
      dnn/src/arm_common/elemwise_helper/elemwise_op.h
  56. +0
    -36
      dnn/src/arm_common/elemwise_helper/kimpl/pow.h
  57. +0
    -1
      dnn/src/arm_common/elemwise_helper/op_binary.h
  58. +3
    -1
      dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
  59. +0
    -1537
      dnn/src/arm_common/elemwise_op.h
  60. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
  61. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp
  62. +2
    -1
      dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
  63. +1
    -0
      dnn/src/fallback/elemwise/opr_impl.cpp
  64. +2
    -2
      dnn/src/fallback/elemwise/opr_impl.h
  65. +72
    -0
      dnn/src/fallback/elemwise_helper/elemwise_op.h
  66. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/abs.h
  67. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
  68. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/hswish.h
  69. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/max.h
  70. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/min.h
  71. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/mul.h
  72. +3
    -5
      dnn/src/fallback/elemwise_helper/kimpl/none.h
  73. +11
    -9
      dnn/src/fallback/elemwise_helper/kimpl/relu.h
  74. +1
    -1
      dnn/src/fallback/elemwise_helper/kimpl/sub.h
  75. +1370
    -0
      dnn/src/fallback/elemwise_helper/op_common.h
  76. +0
    -1432
      dnn/src/fallback/elemwise_op.h
  77. +13
    -0
      dnn/src/fallback/general_intrinsic/gi_common.h
  78. +17
    -35
      dnn/src/fallback/general_intrinsic/gi_float.h
  79. +35
    -28
      dnn/src/fallback/general_intrinsic/gi_int.h
  80. +0
    -1
      dnn/test/fallback/elemwise.cpp

+ 1
- 1
dnn/src/aarch64/conv_bias/int8/algos.cpp View File

@@ -12,7 +12,7 @@
#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"


+ 1
- 1
dnn/src/aarch64/conv_bias/quint8/algos.cpp View File

@@ -14,7 +14,7 @@
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp View File

@@ -11,7 +11,7 @@
*/

#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp View File

@@ -12,7 +12,7 @@

#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "midout.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp View File

@@ -12,7 +12,7 @@

#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h"
#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"

#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "midout.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp View File

@@ -13,7 +13,7 @@

#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp View File

@@ -11,7 +11,7 @@
*/

#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp View File

@@ -11,7 +11,7 @@
*/

#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp View File

@@ -12,7 +12,7 @@

#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "midout.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"

#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "midout.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/algos.cpp View File

@@ -17,7 +17,7 @@
#include "src/arm_common/conv_bias/int8/stride1_dotprod.h"
#include "src/arm_common/conv_bias/int8/stride2.h"
#include "src/arm_common/conv_bias/int8/stride2_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/conv_bias/common.h"

#include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp View File

@@ -11,7 +11,7 @@
*/

#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h"
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

#include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct.cpp View File

@@ -10,7 +10,7 @@
*/

#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp View File

@@ -11,7 +11,7 @@

#include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#if MGB_ENABLE_DOT
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "midout.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h View File

@@ -14,7 +14,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

#include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h View File

@@ -12,7 +12,7 @@
#pragma once
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp View File

@@ -14,7 +14,7 @@
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"

#include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h View File

@@ -14,7 +14,7 @@

#include "src/arm_common/conv_bias/intrinsic_helper.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride1.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride2.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp View File

@@ -14,7 +14,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp View File

@@ -11,7 +11,7 @@
*/

#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"



+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h View File

@@ -13,7 +13,7 @@
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h View File

@@ -16,7 +16,7 @@
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/matmul_postprocess.h View File

@@ -12,7 +12,7 @@

#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 61
- 59
dnn/src/arm_common/conv_bias/postprocess_helper.h View File

@@ -13,8 +13,8 @@
  #pragma once

  #include "megdnn/basic_types.h"
+ #include "src/arm_common/elemwise_helper/elemwise_op.h"
  #include "src/arm_common/elemwise_helper/kimpl/op_base.h"
- #include "src/arm_common/elemwise_op.h"
  #include "src/fallback/conv_bias/opr_impl.h"

  #include "midout.h"
@@ -44,29 +44,29 @@ namespace {
  break;

  #define FOR_NONLINEAR_UNARY(_op) \
- megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \
+ megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \
  static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
  bias_type, dst_type, N* OC* OH* OW* pack_oc_size);

- #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
+ #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \
+ static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
  OH* OW);

  #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
- OH* OW, pack_oc_size);
- #define FOR_NONLINEAR_BINARY(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \
+ run(static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
+ OC, OH* OW, pack_oc_size);
+ #define FOR_NONLINEAR_BINARY(_op) \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \
+ static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
  N* OC* OH* OW* pack_oc_size);

  #define FOR_BIAS(_mode) \
@@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
  #undef FOR_BIAS
  #undef HANDLE_IDENTITY

- #define FOR_NONLINEAR_UNARY(_op) \
- megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \
- static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \
- bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
-
- #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
- megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \
- run(static_cast<opctype*>(conv_dst_ptr), \
- reinterpret_cast<const opctype*>(bias_ptr), \
- reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
+ #define FOR_NONLINEAR_UNARY(_op) \
+ megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \
+ run(static_cast<opctype*>(conv_dst_ptr), \
+ reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \
+ N* OC* OH* OW* pack_oc_size);
+
+ #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
+ megdnn::elemwise::OpCallerBinary< \
+ _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \
+ run(static_cast<opctype*>(conv_dst_ptr), \
+ reinterpret_cast<const opctype*>(bias_ptr), \
+ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
  N, OC, OH* OW);

- #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
- megdnn::arm_common:: \
- OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
- static_cast<opctype*>(conv_dst_ptr), \
- reinterpret_cast<const opctype*>(bias_ptr), \
- reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
- dst_type, N, OC, OH* OW, pack_oc_size);
- #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
- megdnn::arm_common:: \
- OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
- static_cast<opctype*>(conv_dst_ptr), \
- reinterpret_cast<const opctype*>(bias_ptr), \
- reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
- dst_type, N, OC, OH* OW, pack_oc_size);
+ #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
+ megdnn::elemwise::OpCallerBinary< \
+ _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \
+ run(static_cast<opctype*>(conv_dst_ptr), \
+ reinterpret_cast<const opctype*>(bias_ptr), \
+ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
+ N, OC, OH* OW, pack_oc_size);
+ #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
+ megdnn::elemwise::OpCallerBinary< \
+ _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \
+ run(static_cast<opctype*>(conv_dst_ptr), \
+ reinterpret_cast<const opctype*>(bias_ptr), \
+ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
+ N, OC, OH* OW, pack_oc_size);

  #define HANDLE_IDENTITY(_caller, _op) \
  case megdnn::NonlineMode::IDENTITY: \
@@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
  #undef FOR_NONLINEAR
  #undef FOR_BIAS

- #define FOR_BINARY_BROADCAST(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
+ #define FOR_BINARY_BROADCAST(_op) \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \
+ static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
  OH* OW);

  #define FOR_BINARY_BROADCAST_NCHWXX(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
- OH* OW, pack_oc_size);
- #define FOR_BINARY(_op) \
- megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
- static_cast<ctype*>(conv_dst_ptr), \
- reinterpret_cast<const ctype*>(bias_ptr), \
- reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \
+ run(static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
+ OC, OH* OW, pack_oc_size);
+ #define FOR_BINARY(_op) \
+ megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \
+ static_cast<ctype*>(conv_dst_ptr), \
+ reinterpret_cast<const ctype*>(bias_ptr), \
+ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
  N* OC* OH* OW* pack_oc_size);

  #define FOR_BIAS(_bias_mode, OH, OW) \


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/algos.cpp View File

@@ -15,7 +15,7 @@
#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
#include "src/arm_common/conv_bias/quint8/stride2.h"
#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/conv_bias/common.h"

#include "midout.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/direct.cpp View File

@@ -10,7 +10,7 @@
*/

#include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp View File

@@ -11,7 +11,7 @@

#include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride1.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/quint8/stride1.h"
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp View File

@@ -12,7 +12,7 @@
#if MGB_ENABLE_DOT
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride2.cpp View File

@@ -12,7 +12,7 @@
#include "src/arm_common/conv_bias/quint8/stride2.h"
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 1
- 1
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp View File

@@ -12,7 +12,7 @@
#if MGB_ENABLE_DOT
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;


+ 4
- 3
dnn/src/arm_common/elemwise/binary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied.
*/
#include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_binary)

using namespace megdnn;
using namespace elemwise;
using namespace arm_common;

namespace {
@@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \
DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
DISPATCH_BINARY( \
@@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
DISPATCH_BINARY( \
FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \


+ 1
- 1
dnn/src/arm_common/elemwise/opr_impl.cpp View File

@@ -13,7 +13,7 @@
#include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise/unary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/metahelper.h"

#include "src/common/utils.h"


+ 1
- 1
dnn/src/arm_common/elemwise/opr_impl.h View File

@@ -12,7 +12,7 @@
#pragma once
#include "src/fallback/elemwise/opr_impl.h"

#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

namespace megdnn {
namespace arm_common {


+ 2
- 1
dnn/src/arm_common/elemwise/ternary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied.
*/
#include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_ternary)

using namespace megdnn;
using namespace elemwise;
using namespace arm_common;

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \


+ 2
- 1
dnn/src/arm_common/elemwise/unary/algo.cpp View File

@@ -10,7 +10,7 @@
* implied.
*/
#include "src/arm_common/elemwise/unary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -20,6 +20,7 @@
MIDOUT_DECL(megdnn_arm_common_elemwise_unary)

using namespace megdnn;
using namespace elemwise;
using namespace arm_common;

bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {


+ 151
- 0
dnn/src/arm_common/elemwise_helper/elemwise_op.h View File

@@ -0,0 +1,151 @@
/**
* \file dnn/src/arm_common/elemwise_helper/elemwise_op.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

namespace megdnn {
namespace elemwise {

using BcastType = megdnn::elemwise::BcastType;

///////////////////////////////// ParamElemVisitor ///////////////////////////

// Specializes ParamElemVisitor and ParamElemVisitorDup for one element type.
//   _ctype        element type of the source pointer
//   _inner_ctype  scalar type the pointer is reinterpreted as for the intrinsic
//   _neon_type    NEON q-register vector type returned
//   _fun_suffix   suffix of the vld1q_* / vdupq_n_* intrinsic (u8, f16, s16)
// ParamElemVisitor loads a whole vector starting at src (vld1q_*).
// ParamElemVisitorDup reads the single element at src and replicates it into
// every lane (vdupq_n_*).
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb(dt_quint8, uint8_t, uint8x16_t, u8);
// The __fp16 specializations are only emitted when the FP16 vector arithmetic
// feature macro is set.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
#endif
cb(dt_int16, int16_t, int16x8_t, s16);
#undef cb

// Visitor for the "Bcast101x4" operand. The primary template is declared only;
// each supported element type gets an explicit specialization from cb().
template <typename ctype>
struct ParamElemVisitorBcast101x4;
// cb() parameters:
//   _ctype        element type of the source pointer
//   _inner_ctype  wider scalar type read from src in one piece
//   _neon_type    NEON vector type returned
//   _fun_suffix   lane suffix of the returned vector (u8, s16, f16)
//   rel_suffix    lane suffix used for the load (u32, s64, u64)
// The visitor loads one _inner_ctype-sized value from src and replicates it
// across the vector (vld1q_dup_*), then reinterprets the register as
// _fun_suffix lanes (vreinterpretq_*).
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \
reinterpret_cast<const _inner_ctype*>(src))); \
} \
}

cb(dt_quint8, uint32_t, uint8x16_t, u8, u32);
cb(dt_int16, int64_t, int16x8_t, s16, s64);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, uint64_t, float16x8_t, f16, u64);
#endif
#undef cb

// Visitor for the "Bcast101x8" operand. The primary template is declared only;
// the sole specialization here is for __fp16.
template <typename ctype>
struct ParamElemVisitorBcast101x8;
// The generated visitor is a plain vld1q_* load from src, the same body as
// ParamElemVisitor (for __fp16 that is one float16x8_t, i.e. eight elements).
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x8<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;

template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
OpCallerBinaryBcast101xDVec<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
OpCallerBinaryVecBcast101xD<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x8<src_ctype> vis2;
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
}
};

template <>
struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(
const src_ctype* src0, const src_ctype* src1, const src_ctype* src2,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
ParamElemVisitor<src_ctype> vis2;
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks,
channel_stride);
}
};
#endif

} // namespace elemwise
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 36
dnn/src/arm_common/elemwise_helper/kimpl/pow.h View File

@@ -1,36 +0,0 @@
/**
* \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "src/arm_common/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace arm_common {

// when __fp16 is available POW is very slow, so it is added here
/////////////////////// POW float only ////////////////////////////
/*!
 * \brief Scalar-only binary power op: computes powf(src0, src1) one element
 *        at a time (SIMD_WIDTH is 1, there is no vector overload).
 */
template <typename src_ctype, typename dst_ctype = src_ctype>
struct PowOp : BinaryOpBase<src_ctype, dst_ctype> {
    using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;

    static constexpr size_t SIMD_WIDTH = 1;

    //! value-returning form: src0 raised to the power src1
    dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
        return powf(src0, src1);
    }

    //! store form: writes the result of the value-returning overload to *dst
    void operator()(
            const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
        *dst = (*this)(src0, src1);
    }
};

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 1
dnn/src/arm_common/elemwise_helper/op_binary.h View File

@@ -18,7 +18,6 @@
#include "src/arm_common/elemwise_helper/kimpl/max.h"
#include "src/arm_common/elemwise_helper/kimpl/min.h"
#include "src/arm_common/elemwise_helper/kimpl/mul.h"
#include "src/arm_common/elemwise_helper/kimpl/pow.h"
#include "src/arm_common/elemwise_helper/kimpl/rmulh.h"
#include "src/arm_common/elemwise_helper/kimpl/sub.h"
#include "src/arm_common/elemwise_helper/kimpl/true_div.h"


+ 3
- 1
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp View File

@@ -15,7 +15,7 @@
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"

#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace {
@@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k(

} // namespace

using namespace elemwise;

namespace megdnn {
namespace arm_common {



+ 0
- 1537
dnn/src/arm_common/elemwise_op.h
File diff suppressed because it is too large
View File


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp View File

@@ -2,7 +2,7 @@
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
*/
#include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -12,6 +12,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_binary)

using namespace megdnn;
using namespace elemwise;
using namespace fallback;

namespace {


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp View File

@@ -3,7 +3,7 @@
*/

#include "src/fallback/elemwise/gi_impl/ternary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -13,6 +13,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_ternary)

using namespace megdnn;
using namespace elemwise;
using namespace fallback;

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \


+ 2
- 1
dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp View File

@@ -2,7 +2,7 @@
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
*/
#include "src/fallback/elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"
@@ -12,6 +12,7 @@
MIDOUT_DECL(megdnn_fallback_elemwise_unary)

using namespace megdnn;
using namespace elemwise;
using namespace fallback;

bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {


+ 1
- 0
dnn/src/fallback/elemwise/opr_impl.cpp View File

@@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)

using namespace megdnn;
using namespace elemwise;
using namespace fallback;

void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {


+ 2
- 2
dnn/src/fallback/elemwise/opr_impl.h View File

@@ -9,7 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"
#include "src/naive/elemwise/opr_impl.h"

namespace megdnn {
@@ -60,7 +60,7 @@ private:
public:
class AlgoBase;
struct KernParam {
BcastType broad_cast_type;
elemwise::BcastType broad_cast_type;
Mode mode;
const TensorND* m_dst;
Handle* handle;


+ 72
- 0
dnn/src/fallback/elemwise_helper/elemwise_op.h View File

@@ -0,0 +1,72 @@
/**
* \file dnn/src/fallback/elemwise_helper/elemwise_op.h
*/

#pragma once

#include "src/fallback/elemwise_helper/op_binary.h"
#include "src/fallback/elemwise_helper/op_common.h"
#include "src/fallback/elemwise_helper/op_ternary.h"
#include "src/fallback/elemwise_helper/op_unary.h"

#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace elemwise {

///////////////////////////////// ParamElemVisitor ///////////////////////////

// Specializes ParamElemVisitor and ParamElemVisitorDup for one element type
// on top of the general-intrinsic (Gi*) wrappers.
//   _ctype        element type of the source pointer
//   _inner_ctype  scalar type the pointer is reinterpreted as before broadcasting
//   _simd_type    GI vector type returned
//   _fun_suffix   suffix selecting GiLoad<suffix> / GiBroadcast<suffix>
// ParamElemVisitor passes src straight to GiLoad<suffix>.
// ParamElemVisitorDup dereferences one scalar at src and hands it to
// GiBroadcast<suffix>.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiBroadcast##_fun_suffix( \
*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);

cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb

// Visitor for the "Bcast101x4" operand. The primary template is declared only;
// each supported element type gets an explicit specialization below.
template <typename ctype>
struct ParamElemVisitorBcast101x4;
// 8-bit element types: read one _inner_ctype-sized scalar at src, pass it to
// GiBroadcast<rel_suffix>, then reinterpret the resulting vector as
// _fun_suffix lanes via GiReinter<rel_suffix>To<_fun_suffix>.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
*reinterpret_cast<const _inner_ctype*>(src))); \
} \
}

cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
// 32-bit element types: the Bcast101x4 visitor is a plain GiLoad<suffix> of
// src (presumably because four 32-bit elements already fill one GI vector --
// TODO confirm against GI_SIMD_LEN_BYTE). _inner_ctype is accepted for
// symmetry with the other cb() variants but is not used in this body.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}

cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb

} // namespace elemwise
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/abs.h View File

@@ -58,7 +58,7 @@ struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
template <>
struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> {
using AbsOpBase::AbsOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using AbsOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h View File

@@ -87,7 +87,7 @@ template <>
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
using FuseAddHSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
// Lane count of the dt_qint32 *source* vector, not of the int8 destination:
// the vector overload below consumes GI_INT32_V2_t operands, and the literal
// this expression replaces was 4. Dividing by sizeof(int8_t) would yield 16
// on a 16-byte vector.
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/hswish.h View File

@@ -83,7 +83,7 @@ template <>
struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
using HSwishOpBase::HSwishOpBase;
using HSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/max.h View File

@@ -77,7 +77,7 @@ struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <>
struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> {
using MaxOpBase::MaxOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MaxOpBase::operator();

void operator()(


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/min.h View File

@@ -74,7 +74,7 @@ struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <>
struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> {
using MinOpBase::MinOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MinOpBase::operator();

void operator()(


+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/mul.h View File

@@ -73,7 +73,7 @@ struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <>
struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> {
using MulOpBase::MulOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using MulOpBase::operator();

void operator()(


+ 3
- 5
dnn/src/fallback/elemwise_helper/kimpl/none.h View File

@@ -54,8 +54,6 @@ struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
}
};

#pragma GCC diagnostic ignored "-Waddress-of-packed-member"

template <>
struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
using NoneOpBase::NoneOpBase;
@@ -63,11 +61,11 @@ struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]);
GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]);
GiStoreInt32(dst, vsrc.val[0]);
GiStoreInt32(dst + 16, vsrc.val[1]);
}
void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), src);
GiStoreInt32(dst, src);
}
};



+ 11
- 9
dnn/src/fallback/elemwise_helper/kimpl/relu.h View File

@@ -112,36 +112,38 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
: ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}

void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc)));
}

int8x8_t operator()(const int32x4x2_t& vsrc) const {
int8x16_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero());
return vqmovn_s16(vcombine_s16(
auto tmp = vqmovn_s16(vcombine_s16(
vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift))));
return vcombine_s8(tmp, tmp);
}
int8x8_t operator()(const float32x4_t& vsrc) const {
int8x16_t operator()(const float32x4_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem0 = vrshlq_s32(vitem0, vshift);
int16x4_t vitem = vqmovn_s32(vitem0);
return vqmovn_s16(vcombine_s16(vitem, vitem));
auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem));
return vcombine_s8(tmp, tmp);
}
void operator()(const int32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
}
void operator()(const float32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(src, this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
}
};



+ 1
- 1
dnn/src/fallback/elemwise_helper/kimpl/sub.h View File

@@ -73,7 +73,7 @@ struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
template <>
struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> {
using SubOpBase::SubOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using SubOpBase::operator();

void operator()(


+ 1370
- 0
dnn/src/fallback/elemwise_helper/op_common.h
File diff suppressed because it is too large
View File


+ 0
- 1432
dnn/src/fallback/elemwise_op.h
File diff suppressed because it is too large
View File


+ 13
- 0
dnn/src/fallback/general_intrinsic/gi_common.h View File

@@ -13,6 +13,7 @@

#include "math.h"
#include "stdint.h"
#include "string.h"

#if defined(_WIN32)
#include <intrin.h>
@@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16)));
/* Max/Min of two values. The whole expansion is parenthesized so the macros
 * compose safely inside larger expressions (e.g. 2 * Max(a, b)).
 * Note: each argument may be evaluated twice, so avoid side effects. */
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))

// Map the v_fma_* helpers onto NEON intrinsics: the vfmaq_* family when
// __ARM_FEATURE_FMA is defined together with GI_NEON64_INTRINSICS, otherwise
// the vmlaq_* multiply-accumulate family. Arguments are forwarded unchanged,
// with the accumulator `c` first.
#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS)
#define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane))
#else
#define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a))
#define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a))
#define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane))
#endif
#endif

typedef struct {
GI_INT32_t val[2];
} GI_INT32_V2_t;


+ 17
- 35
dnn/src/fallback/general_intrinsic/gi_float.h View File

@@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) {
#elif defined(GI_SSE2_INTRINSICS)
return _mm_castps_si128(In);
#else
return *(GI_INT32_t*)(&In);
GI_INT32_t ret;
memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
return ret;
#endif
}

@@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) {
#elif defined(GI_SSE2_INTRINSICS)
return _mm_castps_si128(In);
#else
return *(GI_UINT32_t*)(&In);
GI_UINT32_t ret;
memcpy(&ret, &In, GI_SIMD_LEN_BYTE);
return ret;
#endif
}

@@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) {
#elif defined(GI_SSE2_INTRINSICS)
return _mm_castsi128_ps(Vector);
#else
return *(GI_FLOAT32_t*)(&Vector);
GI_FLOAT32_t ret;
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
return ret;
#endif
}

@@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) {
#elif defined(GI_SSE2_INTRINSICS)
return _mm_castsi128_ps(Vector);
#else
return *(GI_FLOAT32_t*)(&Vector);
GI_FLOAT32_t ret;
memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE);
return ret;
#endif
}

@@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half);
return vcvtq_s32_f32(vaddq_f32(Vector, vinc0));
#endif
#elif defined(GI_SSE2_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f);
@@ -322,11 +330,7 @@ GI_FORCEINLINE
GI_FLOAT32_t GiMultiplyAddFloat32(
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) {
#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
return vfmaq_f32(VectorSum, Vector1, Vector2);
#else
return vmlaq_f32(VectorSum, Vector1, Vector2);
#endif
return v_fma_ps_f32(VectorSum, Vector1, Vector2);
#elif defined(GI_FMA3_INTRINSICS)
return _mm_fmadd_ps(Vector1, Vector2, VectorSum);
#elif defined(GI_SSE2_INTRINSICS)
@@ -352,11 +356,7 @@ GI_FORCEINLINE
GI_FLOAT32_t GiMultiplyAddScalarFloat32(
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) {
#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
return vfmaq_n_f32(VectorSum, Vector, Scalar);
#else
return vfmla_n_f32(VectorSum, Vector, Scalar);
#endif
return v_fma_n_f32(VectorSum, Vector, Scalar);
#elif defined(GI_SSE2_INTRINSICS)
return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector);
#else
@@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32(
}

#if defined(GI_NEON_INTRINSICS)
#if defined(__ARM_FEATURE_FMA)
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
}
GIMULTIPLYADDLANFLOAT32(0)
GIMULTIPLYADDLANFLOAT32(1)
#undef GIMULTIPLYADDLANFLOAT32
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
}
GIMULTIPLYADDLANFLOAT32(2)
GIMULTIPLYADDLANFLOAT32(3)
#else
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \
}
GIMULTIPLYADDLANFLOAT32(0)
GIMULTIPLYADDLANFLOAT32(1)
@@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1)
#define GIMULTIPLYADDLANFLOAT32(i) \
GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \
GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \
return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \
}
GIMULTIPLYADDLANFLOAT32(2)
GIMULTIPLYADDLANFLOAT32(3)
#endif
#undef GIMULTIPLYADDLANFLOAT32
#elif defined(GI_SSE2_INTRINSICS)



+ 35
- 28
dnn/src/fallback/general_intrinsic/gi_int.h View File

@@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) {
}

GI_FORCEINLINE
GI_INT32_t GiLoadInt32(const int32_t* Buffer) {
GI_INT32_t GiLoadInt32(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS)
return vld1q_s32(Buffer);
return vld1q_s32((int32_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS)
return _mm_loadu_si128((const __m128i*)Buffer);
#else
GI_INT32_t ret;
const int32_t* ptr = (int32_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
ret[i] = Buffer[i];
ret[i] = ptr[i];
}
return ret;
#endif
}

GI_FORCEINLINE
GI_INT8_t GiLoadInt8(const int8_t* Buffer) {
GI_INT8_t GiLoadInt8(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS)
return vld1q_s8(Buffer);
return vld1q_s8((int8_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS)
return _mm_loadu_si128((const __m128i*)Buffer);
#else
GI_INT8_t ret;
const int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
ret[i] = Buffer[i];
ret[i] = ptr[i];
}
return ret;
#endif
}

GI_FORCEINLINE
void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) {
void GiStoreInt32(void* Buffer, GI_INT32_t Vector) {
#if defined(GI_NEON_INTRINSICS)
vst1q_s32(Buffer, Vector);
vst1q_s32((int32_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector);
#else
int32_t* ptr = (int32_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
}
#endif
}

#if defined(GI_NEON_INTRINSICS)
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
vst1q_lane_s32(Buffer, Vector, i); \
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
vst1q_lane_s32((int32_t*)Buffer, Vector, i); \
}

#elif defined(GI_SSE2_INTRINSICS)

#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \
_mm_store_ss( \
(float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \
}
#else
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \
*Buffer = Vector[i]; \
#define GISTORELANEINT32(i) \
GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \
*((int32_t*)Buffer) = Vector[i]; \
}
#endif

@@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) {
}

GI_FORCEINLINE
void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) {
void GiStoreInt16(void* Buffer, GI_INT16_t Vector) {
#if defined(GI_NEON_INTRINSICS)
vst1q_s16(Buffer, Vector);
vst1q_s16((int16_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector);
#else
int16_t* ptr = (int16_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
}
#endif
}

GI_FORCEINLINE
void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS)
vst1q_s8(Buffer, Vector);
vst1q_s8((int8_t*)Buffer, Vector);
#elif defined(GI_SSE2_INTRINSICS)
_mm_storeu_si128((__m128i*)Buffer, Vector);
#else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
}
#endif
}

GI_FORCEINLINE
void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS)
vst1_s8(Buffer, vget_low_s8(Vector));
vst1_s8((int8_t*)Buffer, vget_low_s8(Vector));
#elif defined(GI_SSE2_INTRINSICS)
_mm_storel_epi64((__m128i*)Buffer, Vector);
#else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
Buffer[i] = Vector[i];
ptr[i] = Vector[i];
}
#endif
}

GI_FORCEINLINE
void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) {
void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) {
#if defined(GI_NEON_INTRINSICS)
vst1_s8(Buffer, vget_high_s8(Vector));
vst1_s8((int8_t*)Buffer, vget_high_s8(Vector));
#elif defined(GI_SSE2_INTRINSICS)
_mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector));
#else
int8_t* ptr = (int8_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) {
Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i];
}
#endif
}


+ 0
- 1
dnn/test/fallback/elemwise.cpp View File

@@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) {
checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
}


TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());


Loading…
Cancel
Save