Browse Source

fix(build): split some cpp files which consume too much memory when building

make the build possible in an 8GB DDR memory environment when building with -j8

GitOrigin-RevId: d0c442b41d
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
32717b0ca4
83 changed files with 2058 additions and 1004 deletions
  1. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp
  2. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
  3. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp
  4. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp
  5. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
  6. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp
  7. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp
  8. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
  9. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp
  10. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp
  11. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
  12. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp
  13. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp
  14. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
  15. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp
  16. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp
  17. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
  18. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp
  19. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp
  20. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
  21. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp
  22. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp
  23. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
  24. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp
  25. +8
    -5
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
  26. +8
    -5
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
  27. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp
  28. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
  29. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp
  30. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp
  31. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
  32. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp
  33. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp
  34. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
  35. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp
  36. +3
    -2
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp
  37. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
  38. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp
  39. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp
  40. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
  41. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp
  42. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp
  43. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
  44. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp
  45. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp
  46. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
  47. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp
  48. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp
  49. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
  50. +15
    -0
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp
  51. +6
    -4
      dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
  52. +3
    -18
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h
  53. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp
  54. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp
  55. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp
  56. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp
  57. +3
    -18
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h
  58. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp
  59. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp
  60. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp
  61. +21
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp
  62. +2
    -2
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
  63. +5
    -445
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
  64. +481
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h
  65. +19
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp
  66. +19
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp
  67. +19
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp
  68. +19
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp
  69. +19
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp
  70. +2
    -5
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp
  71. +2
    -5
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp
  72. +16
    -0
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp
  73. +16
    -0
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp
  74. +4
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp
  75. +4
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp
  76. +4
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp
  77. +4
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp
  78. +275
    -0
      dnn/src/fallback/elemwise/opr_binary_impl.cpp
  79. +0
    -252
      dnn/src/fallback/elemwise/opr_impl.cpp
  80. +122
    -0
      dnn/src/fallback/elemwise/opr_unary_impl.cpp
  81. +138
    -0
      dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp
  82. +115
    -0
      dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp
  83. +1
    -213
      dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(2);
INSTANTIATION_CONV_S1_BIAS(2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(5);
INSTANTIATION_CONV_S2_BIAS(2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(5);
INSTANTIATION_CONV_S1_BIAS(3);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(2);
INSTANTIATION_CONV_S2_BIAS(3);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(3);
INSTANTIATION_CONV_S1_BIAS(5);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(7);
INSTANTIATION_CONV_S2_BIAS(5);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(7);
INSTANTIATION_CONV_S1_BIAS(7);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(3);
INSTANTIATION_CONV_S2_BIAS(7);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(7);
// vim: syntax=cpp.doxygen

+ 8
- 5
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h View File

@@ -469,9 +469,12 @@ void conv_bias::conv_direct_fp32_nchw44(
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)

#define INSTANTIATION_CONV_S1(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)
#define INSTANTIATION_CONV_S1_NO_BIAS(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS)

// vim: syntax=cpp.doxygen
#define INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(filter_size) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANTIATION_CONV_S1_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS)

// vim: syntax=cpp.doxygen

+ 8
- 5
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h View File

@@ -550,9 +550,12 @@ void conv_bias::conv_direct_fp32_nchw44(
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)

#define INSTANTIATION_CONV_S2(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)
#define INSTANTIATION_CONV_S2_NO_BIAS(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS)

// vim: syntax=cpp.doxygen
#define INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(filter_size) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANTIATION_CONV_S2_BIAS(filter_size) FOR_OP(filter_size, BiasMode::BIAS)

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 1);
INSTANCE_CONV_BIAS(2, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(2, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 2);
INSTANCE_CONV_BIAS(2, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(2, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 1);
INSTANCE_CONV_BIAS(3, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(3, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,4 +11,5 @@
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 2);
INSTANCE_CONV_BIAS(3, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(3, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(5, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(5, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(5, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(5, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(7, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(7, 1);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(7, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2);
// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp View File

@@ -0,0 +1,15 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(7, 2);
// vim: syntax=cpp.doxygen

+ 6
- 4
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h View File

@@ -928,9 +928,11 @@ void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
INSTANTIATION(stride, filter, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, filter, bias, HSwishOp<dt_float32>)

#define INSTANCE_CONV(filter, stride) \
FOR_OP(stride, filter, BiasMode::NO_BIAS) \
FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, filter, BiasMode::BIAS)
#define INSTANCE_CONV_NO_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::NO_BIAS)

#define INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(filter, stride) \
FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_BIAS(filter, stride) FOR_OP(stride, filter, BiasMode::BIAS)

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp → dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -265,7 +265,8 @@ void conv_direct_sdot_int8_nchw44(

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void \
conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, filter_size>( \
megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \
dst_type, stride, bias_mode, Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, const int32_t* bias, \
const int oh_size, const int oc, const int ic, const Op& op);
@@ -284,22 +285,6 @@ void conv_direct_sdot_int8_nchw44(
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(1)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_2x2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(1, 2);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_3x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(1, 3);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_5x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(1, 5);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1_7x7.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(1, 7);

#endif
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp → dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -266,7 +266,8 @@ void conv_direct_sdot_int8_nchw44(

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void \
conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, filter_size>( \
megdnn::arm_common::direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< \
dst_type, stride, bias_mode, Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, const int32_t* bias, \
const int oh_size, const int oc, const int ic, const Op& op);
@@ -285,22 +286,6 @@ void conv_direct_sdot_int8_nchw44(
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_2x2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(2, 2);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_3x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(2, 3);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_5x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(2, 5);

#endif
// vim: syntax=cpp.doxygen

+ 21
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp View File

@@ -0,0 +1,21 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2_7x7.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.h"
#if MGB_ENABLE_DOT
using namespace megdnn;
using namespace arm_common;

FOR_BIAS(2, 7);

#endif
// vim: syntax=cpp.doxygen

+ 2
- 2
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -45,4 +45,4 @@ public:
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 5
- 445
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp View File

@@ -13,336 +13,9 @@

#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"

namespace megdnn {
namespace arm_common {
namespace {
/**
* @brief core code for calculation patten
*
* @tparam src_idx is offset of src reg
* @tparam weight_idx is offset of weight reg
* @tparam c_dim is output channel
* @tparam Func mla operation funcion
* @tparam stride
* @tparam T outpur regs type
* @tparam T2 src regs type
* @tparam T3 weight regs type
* @tparam T4 temp regs type
*/

template <
int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
template <
int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
typename T3, typename T4>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl(
c, src, weight, temp);
}
template <
int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl(
c, src, weight);
};
template <
int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(
src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
c[1][0] = vdotq_s32_h(
src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]);
c[0][1] = vdotq_s32_h(
src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]);
c[1][1] = vdotq_s32_h(
src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]);
c[0][2] = vdotq_s32_h(
src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]);
c[1][2] = vdotq_s32_h(
src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]);
c[0][3] = vdotq_s32_h(
src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]);
c[1][3] = vdotq_s32_h(
src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]);

c[0][4] = vdotq_s32_h(
src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
c[1][4] = vdotq_s32_h(
src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]);
c[0][5] = vdotq_s32_h(
src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]);
c[1][5] = vdotq_s32_h(
src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]);
c[0][6] = vdotq_s32_h(
src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]);
c[1][6] = vdotq_s32_h(
src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]);
c[0][7] = vdotq_s32_h(
src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]);
c[1][7] = vdotq_s32_h(
src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <
int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(
src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
c[0][1] = vdotq_s32_h(
src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]);
c[0][2] = vdotq_s32_h(
src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]);
c[0][3] = vdotq_s32_h(
src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]);
c[0][4] = vdotq_s32_h(
src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
c[0][5] = vdotq_s32_h(
src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]);
c[0][6] = vdotq_s32_h(
src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]);
c[0][7] = vdotq_s32_h(
src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
static void impl(
const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 1;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

// Stride-1 int8 direct-conv kernel, 2-row filter (row width padded to 4):
// same scheme as the 1-row case, with one load/accumulate pass per row.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 2;
        constexpr int filter_width = 4;  // one filter row per dot4
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];  // scratch regs consumed by vdotq_s32_h
            // row 0
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            // row 1
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }

        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 3-row filter (row width padded to 4):
// one load/accumulate pass per filter row.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 3;
        constexpr int filter_width = 4;  // one filter row per dot4
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];  // scratch regs consumed by vdotq_s32_h
            // row 0
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);

            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            // row 1
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc);

            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            // row 2
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 5-row filter (row width padded to 8,
// so each filter row occupies two weight regs and needs 8 + 4 src vectors).
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 5;
        constexpr int filter_width = 8;  // 5 padded up to 8
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;    // two regs hold one 8-wide filter row

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step: load both weight halves, 8 src vectors for the
// left half (weight_idx 0), then 4 more for the right half (weight_idx 1)
#define cb(step)                                                             \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                          \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                 \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);            \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 7-row filter (row width padded to 8);
// identical row scheme to the 5-row kernel, unrolled over 7 rows.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 7;
        constexpr int filter_width = 8;  // 7 padded up to 8
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;    // two regs hold one 8-wide filter row

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step: both weight halves, then left (idx 0) and
// right (idx 1) halves of the src window
#define cb(step)                                                             \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                          \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                 \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);            \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

            UNROLL_CALL_RAW(7, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
} // namespace

namespace int8_direct_nchw_nchw44 {
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)}
@@ -444,115 +117,9 @@ void pack_nchw_src_for_nchw44_conv<1>(
}
}

/**
 * Stride-1 int8 direct convolution, NCHW input / NCHW44 output.
 * Output is tiled as (big_oc_step=8 channels) x (ow_step=8 width pixels);
 * full tiles call the templated kernel directly, while the ow remainder and
 * the trailing 4-channel oc block use kernels selected once up front.
 * NOTE(review): "Diect" is a long-standing typo for "Direct"; renaming
 * would touch every caller, so it is kept as-is.
 */
template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
    static void impl(
            const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp,
            int8_t* dst, const size_t oc, const size_t ic, const size_t ih,
            const size_t iw, const size_t oh, const size_t ow, const Op& op) {
        MEGDNN_MARK_USED_VAR(temp);
        constexpr int stride = 1;
        constexpr size_t fh = filter_size;
        // filter width is padded to a multiple of 4 by the weight packing
        constexpr size_t fw = (filter_size + 3) / 4 * 4;
        constexpr size_t ic_step = 1;
        constexpr size_t big_oc_step = 8;
        constexpr size_t oc_step = 4;
        constexpr size_t ih_step = 1;
        constexpr size_t oh_step = 1;
        constexpr size_t ow_step = 8;  // output pixels per kernel call
        constexpr size_t stride_h = stride;
        constexpr size_t stride_w = stride;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;

        const size_t img_stride = oh * ow;
        const size_t ow_end = ow / ow_step * ow_step;
        const size_t ow_remain = ow - ow_end;
        const size_t oc_end = oc / big_oc_step * big_oc_step;
        const size_t oc_remain = oc - oc_end;
        const int ld_dst_oc = oc_step * img_stride;

        // resolve the ow-remainder kernels once, outside the hot loops
        using remain_fun = std::function<void(
                const int8_t* src_ptr, const int8_t* weight_ptr,
                const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw,
                int ld_dst_oc, const Op& op)>;
        remain_fun kern_big_oc_remain = nullptr;
        remain_fun kern_small_oc_remain = nullptr;
        switch (ow_remain) {
#define cb(step)                                                              \
    case step:                                                                \
        kern_big_oc_remain = KerNeonXXs2NchwNchw44<                           \
                bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \
        kern_small_oc_remain = KerNeonXXs2NchwNchw44<                         \
                bias_mode, Op, step, filter_size, oc_step, stride>::impl;     \
        break;

            UNROLL_CALL_RAW(8, cb);
            default:
                megdnn_assert(0, "no remain %zu for kern", ow_remain);
        }
// fix: undef the local helper so `cb` does not leak to code including this
#undef cb

        // main path: full 8-channel oc blocks
        for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;

                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, big_oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_big_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }

        // tail path: one 4-channel oc block when oc % 8 != 0
        if (oc_remain > 0) {
            size_t oc_idx = oc_end;
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_small_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }
    }
};

#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>;
template struct megdnn::arm_common::int8_direct_nchw_nchw44:: \
ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>;

#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \
INSTANCE_CONV_KERN_FUN( \
@@ -566,17 +133,10 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
INSTANCE_BIAS_MODE_PARAM(stride, 7)

INSTANCE_CONV_KERN(1);
#define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter)

} // namespace int8_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 481
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h View File

@@ -0,0 +1,481 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"

namespace megdnn {
namespace arm_common {
namespace {
/**
 * @brief core code for calculation pattern
*
* @tparam src_idx is offset of src reg
* @tparam weight_idx is offset of weight reg
* @tparam c_dim is output channel
 * @tparam Func mla operation function
* @tparam stride
 * @tparam T output regs type
* @tparam T2 src regs type
* @tparam T3 weight regs type
* @tparam T4 temp regs type
*/

// Primary template: declared only. Each supported (c_dim, stride) pair must
// provide a specialization, so unsupported combinations fail to link.
template <
        int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
        typename T3, typename T4>
struct ShiftCalHelper {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
// Thin dispatch wrapper: forwards to the matching ShiftCalHelper
// specialization (4-arg form, with temp registers).
template <
        int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
        typename T3, typename T4>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
    ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl(
            c, src, weight, temp);
}
// Overload without the temp-register bundle: forwards to the 3-arg
// ShiftCalHelper::impl; `int` merely fills the unused T4 slot.
// fix: dropped the stray ';' after the function body (-Wextra-semi, and
// now consistent with the 4-arg cal_helper above).
template <
        int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2,
        typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
    ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl(
            c, src, weight);
}
// stride-1, two oc-group specialization: accumulate one dot4 of `weight`
// into all 8 ow accumulators of both oc groups. Src vectors are indexed
// modulo 8 over a preloaded ring; the four temp regs are reused cyclically.
template <
        int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        // ow pixels 0..3, both oc groups
        c[0][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
        c[1][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]);
        c[0][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[2]);
        c[1][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]);
        c[0][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]);
        c[1][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]);
        c[0][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]);
        c[1][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]);

        // ow pixels 4..7, both oc groups
        c[0][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
        c[1][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]);
        c[0][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]);
        c[1][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]);
        c[0][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]);
        c[1][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]);
        c[0][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]);
        c[1][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]);
    }
    // 3-arg (no-temp) form intentionally declared but not defined here
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
// stride-1, single oc-group specialization: accumulate one dot4 of
// weight[0] into the 8 ow accumulators; src indices wrap mod 8 over the
// preloaded ring and the four temp regs are reused cyclically.
template <
        int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        c[0][0] = vdotq_s32_h(
                src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]);
        c[0][1] = vdotq_s32_h(
                src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]);
        c[0][2] = vdotq_s32_h(
                src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]);
        c[0][3] = vdotq_s32_h(
                src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]);
        c[0][4] = vdotq_s32_h(
                src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]);
        c[0][5] = vdotq_s32_h(
                src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]);
        c[0][6] = vdotq_s32_h(
                src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]);
        c[0][7] = vdotq_s32_h(
                src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]);
    }
    // 3-arg (no-temp) form intentionally declared but not defined here
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};

// Stride-1 int8 direct-conv kernel, 1-row filter (row width padded to 4):
// produces 8 output-width pixels for one 4- or 8-channel oc block.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
    // src_ptr: packed NCHW input tile; weight_ptr: packed filter; dst_ptr:
    // NCHW44 output tile; ld_dst_oc: stride between 4-oc output groups.
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 1;
        constexpr int filter_width = 4;  // one filter row fits a single dot4
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;  // one input channel per iteration
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption from the addressing
        // below; confirm against the src packing pass
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;       // 8 src vectors feed 8 ow pixels
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];  // scratch regs consumed by vdotq_s32_h
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }

        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 2-row filter (row width padded to 4):
// same scheme as the 1-row case, with one load/accumulate pass per row.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 2;
        constexpr int filter_width = 4;  // one filter row per dot4
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];  // scratch regs consumed by vdotq_s32_h
            // row 0
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            // row 1
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }

        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 3-row filter (row width padded to 4):
// one load/accumulate pass per filter row.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 3;
        constexpr int filter_width = 4;  // one filter row per dot4
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];  // scratch regs consumed by vdotq_s32_h
            // row 0
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);

            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            // row 1
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc);

            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            // row 2
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 5-row filter (row width padded to 8,
// so each filter row occupies two weight regs and needs 8 + 4 src vectors).
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 5;
        constexpr int filter_width = 8;  // 5 padded up to 8
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;    // two regs hold one 8-wide filter row

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step: load both weight halves, 8 src vectors for the
// left half (weight_idx 0), then 4 more for the right half (weight_idx 1)
#define cb(step)                                                             \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                          \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                 \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);            \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

// Stride-1 int8 direct-conv kernel, 7-row filter (row width padded to 8);
// identical row scheme to the 5-row kernel, unrolled over 7 rows.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
    static void impl(
            const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr,
            int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 7;
        constexpr int filter_width = 8;  // 7 padded up to 8
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;    // two regs hold one 8-wide filter row

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        // number of 4-channel oc groups handled at once (via OCHelper)
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];  // accumulators: [oc group][ow pixel]
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// one filter row per step: both weight halves, then left (idx 0) and
// right (idx 1) halves of the src window
#define cb(step)                                                             \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                          \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                 \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);            \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                                \
            src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

            UNROLL_CALL_RAW(7, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
} // namespace

namespace int8_direct_nchw_nchw44 {
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)}
* pack interleave two adjacent row in filter to one row
* */
/**
 * Stride-1 int8 direct convolution, NCHW input / NCHW44 output.
 * Output is tiled as (big_oc_step=8 channels) x (ow_step=8 width pixels);
 * full tiles call the templated kernel directly, while the ow remainder and
 * the trailing 4-channel oc block use kernels selected once up front.
 * NOTE(review): "Diect" is a long-standing typo for "Direct"; renaming
 * would touch every caller, so it is kept as-is.
 */
template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
    static void impl(
            const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp,
            int8_t* dst, const size_t oc, const size_t ic, const size_t ih,
            const size_t iw, const size_t oh, const size_t ow, const Op& op) {
        MEGDNN_MARK_USED_VAR(temp);
        constexpr int stride = 1;
        constexpr size_t fh = filter_size;
        // filter width is padded to a multiple of 4 by the weight packing
        constexpr size_t fw = (filter_size + 3) / 4 * 4;
        constexpr size_t ic_step = 1;
        constexpr size_t big_oc_step = 8;
        constexpr size_t oc_step = 4;
        constexpr size_t ih_step = 1;
        constexpr size_t oh_step = 1;
        constexpr size_t ow_step = 8;  // output pixels per kernel call
        constexpr size_t stride_h = stride;
        constexpr size_t stride_w = stride;
        // 16 bytes per packed input pixel -- assumption, confirm at packer
        constexpr int pack_iw_len = 16;

        const size_t img_stride = oh * ow;
        const size_t ow_end = ow / ow_step * ow_step;
        const size_t ow_remain = ow - ow_end;
        const size_t oc_end = oc / big_oc_step * big_oc_step;
        const size_t oc_remain = oc - oc_end;
        const int ld_dst_oc = oc_step * img_stride;

        // resolve the ow-remainder kernels once, outside the hot loops
        using remain_fun = std::function<void(
                const int8_t* src_ptr, const int8_t* weight_ptr,
                const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw,
                int ld_dst_oc, const Op& op)>;
        remain_fun kern_big_oc_remain = nullptr;
        remain_fun kern_small_oc_remain = nullptr;
        switch (ow_remain) {
#define cb(step)                                                              \
    case step:                                                                \
        kern_big_oc_remain = KerNeonXXs2NchwNchw44<                           \
                bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \
        kern_small_oc_remain = KerNeonXXs2NchwNchw44<                         \
                bias_mode, Op, step, filter_size, oc_step, stride>::impl;     \
        break;

            UNROLL_CALL_RAW(8, cb);
            default:
                megdnn_assert(0, "no remain %zu for kern", ow_remain);
        }
// fix: undef the local helper so `cb` does not leak to every .cpp that
// includes this header
#undef cb

        // main path: full 8-channel oc blocks
        for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;

                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, big_oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_big_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }

        // tail path: one 4-channel oc block when oc % 8 != 0
        if (oc_remain > 0) {
            size_t oc_idx = oc_end;
            const size_t weight_offset = oc_idx * ic * fh * fw;
            for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
                for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                    KerNeonXXs2NchwNchw44<
                            bias_mode, Op, ow_step, filter_size, oc_step, stride>::
                            impl(src + src_offset, filter + weight_offset,
                                 bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc,
                                 op);
                }
                if (ow_remain > 0) {
                    const size_t src_offset =
                            (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                            ic_step * pack_iw_len;
                    const size_t dst_offset =
                            oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                    kern_small_oc_remain(
                            src + src_offset, filter + weight_offset, bias + oc_idx,
                            dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
                }
            }
        }
    }
};

// Explicitly instantiate one (stride, filter, bias, Op) kernel variant.
#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
    template struct megdnn::arm_common::int8_direct_nchw_nchw44:: \
            ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>;

// Expand over the three supported post-ops (typecvt / relu / hswish).
#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \
    INSTANCE_CONV_KERN_FUN( \
            stride, filter, bias_mode, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANCE_CONV_KERN_FUN( \
            stride, filter, bias_mode, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANCE_CONV_KERN_FUN( \
            stride, filter, bias_mode, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

// Expand over the supported bias modes.
#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

// Entry point used by the per-filter-size .cpp files; takes `filter` so each
// translation unit instantiates only one filter size (bounds build memory).
#define INSTANCE_CONV_KERN(stride, filter) INSTANCE_BIAS_MODE_PARAM(stride, filter)

} // namespace int8_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 19
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp View File

@@ -0,0 +1,19 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_1x1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h"
using namespace megdnn;
using namespace arm_common;

// Explicit instantiations for stride 1, filter size 1 live in their own TU
// to keep per-file compile-time memory low (the point of this file split).
INSTANCE_CONV_KERN(1, 1);

// vim: syntax=cpp.doxygen

+ 19
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp View File

@@ -0,0 +1,19 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_2x2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h"
using namespace megdnn;
using namespace arm_common;

// Explicit instantiations for stride 1, filter size 2 live in their own TU
// to keep per-file compile-time memory low (the point of this file split).
INSTANCE_CONV_KERN(1, 2);

// vim: syntax=cpp.doxygen

+ 19
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp View File

@@ -0,0 +1,19 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_3x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h"
using namespace megdnn;
using namespace arm_common;

// Explicit instantiations for stride 1, filter size 3 live in their own TU
// to keep per-file compile-time memory low (the point of this file split).
INSTANCE_CONV_KERN(1, 3);

// vim: syntax=cpp.doxygen

+ 19
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp View File

@@ -0,0 +1,19 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_5x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h"
using namespace megdnn;
using namespace arm_common;

INSTANCE_CONV_KERN(1, 5);

// vim: syntax=cpp.doxygen

+ 19
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp View File

@@ -0,0 +1,19 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1_7x7.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.h"
using namespace megdnn;
using namespace arm_common;

INSTANCE_CONV_KERN(1, 7);

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,8 +12,5 @@
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(2, 1);
INSTANCE_CONV(3, 1);
INSTANCE_CONV(5, 1);
INSTANCE_CONV(7, 1);

// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_2x2s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,8 +12,5 @@
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(2, 2);
INSTANCE_CONV(3, 2);
INSTANCE_CONV(5, 2);
INSTANCE_CONV(7, 2);

// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 16
- 0
dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp View File

@@ -0,0 +1,16 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(3, 1);

// vim: syntax=cpp.doxygen

+ 16
- 0
dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp View File

@@ -0,0 +1,16 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_3x3s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(3, 2);

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,5 +10,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(5, 1);

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_5x5s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,5 +10,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(5, 2);

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,5 +10,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(7, 1);

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp → dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_7x7s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,5 +10,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(7, 2);

// vim: syntax=cpp.doxygen

+ 275
- 0
dnn/src/fallback/elemwise/opr_binary_impl.cpp View File

@@ -0,0 +1,275 @@
/**
* \file dnn/src/fallback/elemwise/opr_binary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_binary)

namespace megdnn {
namespace fallback {

template <typename dtype, uint32_t mode>
void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) {
using ctype = typename DTypeTrait<dtype>::ctype;
using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>;

MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) {
if (param.max_ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(1)) {
auto tot = param.size;
auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
for (size_t i = 0; i < tot; ++i) {
dst[i] = Kern::apply(a[i * as], b[i * bs]);
}
});
return;
}
MIDOUT_END();
}

if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) {
return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}

if (param.max_ndim == 2) {
if (param[0].layout.ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(21)) {
auto as = param[0].layout.stride[0],
bs0 = param[1].layout.stride[0],
bs1 = param[1].layout.stride[1];
auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
ptrdiff_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
dst[toff] =
Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]);
++toff;
}
}
});
return;
}
MIDOUT_END();
}

MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(22)) {
megdnn_assert(param[1].layout.ndim == 1);
auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0],
as1 = param[0].layout.stride[1];
auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
ptrdiff_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]);
++toff;
}
}
});
return;
}
MIDOUT_END();
}

if (param.max_ndim == 3) {
auto brd_101 = [](const TensorND& t) {
auto&& l = t.layout;
return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0;
};
if (param[0].layout.ndim == 1 && brd_101(param[1])) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(31)) {
auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1];
auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1],
n2 = param[1].layout.shape[2];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
size_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
for (size_t k = 0; k < n2; ++k) {
dst[toff] = Kern::apply(a[as * toff], b[bs * j]);
++toff;
}
}
}
});
return;
}
MIDOUT_END();
}
if (param[1].layout.ndim == 1 && brd_101(param[0])) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(32)) {
auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0];
auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1],
n2 = param[0].layout.shape[2];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;
MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
size_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
for (size_t k = 0; k < n2; ++k) {
dst[toff] = Kern::apply(a[as * j], b[bs * toff]);
++toff;
}
}
}
});
return;
}
MIDOUT_END();
}
}

naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}
MIDOUT_END();
}

#define SWITCH_DTYPE(_cat, _cb) \
switch (m_dst->layout.dtype.enumv()) { \
MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \
: megdnn_throw("bad dtype"); \
}

template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_INT() {
auto param = make_elemwise_op_param<2>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return binary_kern<_dt, mode>(param);

SWITCH_DTYPE(INT, cb)

#undef cb
}

template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_FLOAT() {
auto param = make_elemwise_op_param<2>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return binary_kern<_dt, mode>(param);

SWITCH_DTYPE(FLOAT, cb)

#undef cb
}

// Tear down the local dispatch-helper macro so its name does not leak past
// the exec_BINARY_INT / exec_BINARY_FLOAT definitions above.
// (The original had this #undef duplicated; a single occurrence suffices.)
#undef SWITCH_DTYPE
using Mode = param_enumv::Elemwise::Mode;
#define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_BINARY_INT<mode>()
INST(Mode::ABS_GRAD);
INST(Mode::ADD);
INST(Mode::FLOOR_DIV);
INST(Mode::MAX);
INST(Mode::MIN);
INST(Mode::MOD);
INST(Mode::MUL);
INST(Mode::SIGMOID_GRAD);
INST(Mode::SUB);
INST(Mode::SWITCH_GT0);
INST(Mode::TANH_GRAD);
INST(Mode::LT);
INST(Mode::LEQ);
INST(Mode::EQ);
INST(Mode::SHL);
INST(Mode::SHR);
INST(Mode::FUSE_ADD_RELU);
INST(Mode::RMULH);
#undef INST

#define INST(mode) \
template void megdnn::fallback::ElemwiseImpl::exec_BINARY_FLOAT<mode>()
INST(Mode::ABS_GRAD);
INST(Mode::ADD);
INST(Mode::FLOOR_DIV);
INST(Mode::MAX);
INST(Mode::MIN);
INST(Mode::MOD);
INST(Mode::MUL);
INST(Mode::POW);
INST(Mode::SIGMOID_GRAD);
INST(Mode::SUB);
INST(Mode::SWITCH_GT0);
INST(Mode::TANH_GRAD);
INST(Mode::TRUE_DIV);
INST(Mode::LOG_SUM_EXP);
INST(Mode::LT);
INST(Mode::LEQ);
INST(Mode::EQ);
INST(Mode::FUSE_ADD_RELU);
INST(Mode::FUSE_ADD_SIGMOID);
INST(Mode::FUSE_ADD_TANH);
INST(Mode::FAST_TANH_GRAD);
INST(Mode::ATAN2);
INST(Mode::H_SWISH_GRAD);
INST(Mode::FUSE_ADD_H_SWISH);
INST(Mode::SILU_GRAD);
INST(Mode::GELU_GRAD);
#undef INST
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 252
dnn/src/fallback/elemwise/opr_impl.cpp View File

@@ -16,8 +16,6 @@

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_unary)
MIDOUT_DECL(megdnn_fallback_elemwise_binary)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
@@ -26,200 +24,6 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)
namespace megdnn {
namespace fallback {

template <typename dtype, uint32_t mode>
void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) {
using ctype = typename DTypeTrait<dtype>::ctype;
using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>;
MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) {
// only specialize for the most common 1-dim case
auto tot = param.size;
auto stride = param[0].layout.stride[0];
auto src0 = param[0];
auto dst_tensor = *m_dst;
if (param.max_ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_unary, ctype, midout_iv(mode),
midout_iv(1)) {
MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict src = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
for (size_t i = 0; i < tot; ++i) {
dst[i] = Kern::apply(src[i * stride]);
}
});
return;
}
MIDOUT_END();
}
naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}
MIDOUT_END();
}

template <typename dtype, uint32_t mode>
void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) {
using ctype = typename DTypeTrait<dtype>::ctype;
using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>;

MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) {
if (param.max_ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(1)) {
auto tot = param.size;
auto as = param[0].layout.stride[0], bs = param[1].layout.stride[0];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
for (size_t i = 0; i < tot; ++i) {
dst[i] = Kern::apply(a[i * as], b[i * bs]);
}
});
return;
}
MIDOUT_END();
}

if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) {
return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}

if (param.max_ndim == 2) {
if (param[0].layout.ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(21)) {
auto as = param[0].layout.stride[0],
bs0 = param[1].layout.stride[0],
bs1 = param[1].layout.stride[1];
auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
ptrdiff_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
dst[toff] =
Kern::apply(a[as * toff], b[bs0 * i + bs1 * j]);
++toff;
}
}
});
return;
}
MIDOUT_END();
}

MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(22)) {
megdnn_assert(param[1].layout.ndim == 1);
auto bs = param[1].layout.stride[0], as0 = param[0].layout.stride[0],
as1 = param[0].layout.stride[1];
auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
ptrdiff_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]);
++toff;
}
}
});
return;
}
MIDOUT_END();
}

if (param.max_ndim == 3) {
auto brd_101 = [](const TensorND& t) {
auto&& l = t.layout;
return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0;
};
if (param[0].layout.ndim == 1 && brd_101(param[1])) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(31)) {
auto as = param[0].layout.stride[0], bs = param[1].layout.stride[1];
auto n0 = param[1].layout.shape[0], n1 = param[1].layout.shape[1],
n2 = param[1].layout.shape[2];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;

MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
size_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
for (size_t k = 0; k < n2; ++k) {
dst[toff] = Kern::apply(a[as * toff], b[bs * j]);
++toff;
}
}
}
});
return;
}
MIDOUT_END();
}
if (param[1].layout.ndim == 1 && brd_101(param[0])) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
midout_iv(32)) {
auto as = param[0].layout.stride[1], bs = param[1].layout.stride[0];
auto n0 = param[0].layout.shape[0], n1 = param[0].layout.shape[1],
n2 = param[0].layout.shape[2];
auto src0 = param[0];
auto src1 = param[1];
auto dst_tensor = *m_dst;
MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict a = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict b = static_cast<ctype*>(src1.raw_ptr());
ctype* __restrict dst =
static_cast<ctype*>(dst_tensor.raw_ptr());
size_t toff = 0;
for (size_t i = 0; i < n0; ++i) {
for (size_t j = 0; j < n1; ++j) {
for (size_t k = 0; k < n2; ++k) {
dst[toff] = Kern::apply(a[as * j], b[bs * toff]);
++toff;
}
}
}
});
return;
}
MIDOUT_END();
}
}

naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}
MIDOUT_END();
}

void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
if (!dst.layout.is_contiguous()) {
return naive::ElemwiseForwardImpl::exec(srcs, dst);
@@ -278,62 +82,6 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
naive::ElemwiseForwardImpl::exec(srcs, dst);
}

#define SWITCH_DTYPE(_cat, _cb) \
switch (m_dst->layout.dtype.enumv()) { \
MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \
: megdnn_throw("bad dtype"); \
}

template <uint32_t mode>
void ElemwiseImpl::exec_UNARY_INT() {
auto param = make_elemwise_op_param<1>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return unary_kern<_dt, mode>(param);

SWITCH_DTYPE(INT, cb)

#undef cb
}

template <uint32_t mode>
void ElemwiseImpl::exec_UNARY_FLOAT() {
auto param = make_elemwise_op_param<1>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return unary_kern<_dt, mode>(param);

SWITCH_DTYPE(FLOAT, cb)

#undef cb
}

template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_INT() {
auto param = make_elemwise_op_param<2>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return binary_kern<_dt, mode>(param);

SWITCH_DTYPE(INT, cb)

#undef cb
}

template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_FLOAT() {
auto param = make_elemwise_op_param<2>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return binary_kern<_dt, mode>(param);

SWITCH_DTYPE(FLOAT, cb)

#undef cb
}

#undef SWITCH_DTYPE

} // namespace fallback
} // namespace megdnn



+ 122
- 0
dnn/src/fallback/elemwise/opr_unary_impl.cpp View File

@@ -0,0 +1,122 @@
/**
* \file dnn/src/fallback/elemwise/opr_unary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_unary)

namespace megdnn {
namespace fallback {

template <typename dtype, uint32_t mode>
void ElemwiseImpl::unary_kern(const ElemwiseOpParamN<1>& param) {
using ctype = typename DTypeTrait<dtype>::ctype;
using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>;
MIDOUT_BEGIN(megdnn_fallback_elemwise_unary, ctype, midout_iv(mode)) {
// only specialize for the most common 1-dim case
auto tot = param.size;
auto stride = param[0].layout.stride[0];
auto src0 = param[0];
auto dst_tensor = *m_dst;
if (param.max_ndim == 1) {
MIDOUT_BEGIN(
megdnn_fallback_elemwise_unary, ctype, midout_iv(mode),
midout_iv(1)) {
MEGDNN_DISPATCH_CPU_KERN_OPR({
ctype* __restrict src = static_cast<ctype*>(src0.raw_ptr());
ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
for (size_t i = 0; i < tot; ++i) {
dst[i] = Kern::apply(src[i * stride]);
}
});
return;
}
MIDOUT_END();
}
naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
}
MIDOUT_END();
}

#define SWITCH_DTYPE(_cat, _cb) \
switch (m_dst->layout.dtype.enumv()) { \
MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) default \
: megdnn_throw("bad dtype"); \
}

template <uint32_t mode>
void ElemwiseImpl::exec_UNARY_INT() {
auto param = make_elemwise_op_param<1>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return unary_kern<_dt, mode>(param);

SWITCH_DTYPE(INT, cb)

#undef cb
}

template <uint32_t mode>
void ElemwiseImpl::exec_UNARY_FLOAT() {
auto param = make_elemwise_op_param<1>();
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return unary_kern<_dt, mode>(param);

SWITCH_DTYPE(FLOAT, cb)

#undef cb
}

#undef SWITCH_DTYPE
using Mode = param_enumv::Elemwise::Mode;
#define INST(mode) template void megdnn::fallback::ElemwiseImpl::exec_UNARY_INT<mode>();
INST(Mode::RELU);
INST(Mode::ABS);
INST(Mode::NEGATE);
#undef INST

#define INST(mode) \
template void megdnn::fallback::ElemwiseImpl::exec_UNARY_FLOAT<mode>();
INST(Mode::RELU);
INST(Mode::ABS);
INST(Mode::ACOS);
INST(Mode::ASIN);
INST(Mode::CEIL);
INST(Mode::COS);
INST(Mode::EXP);
INST(Mode::EXPM1);
INST(Mode::FLOOR);
INST(Mode::LOG);
INST(Mode::LOG1P);
INST(Mode::NEGATE);
INST(Mode::SIGMOID);
INST(Mode::SIN);
INST(Mode::TANH);
INST(Mode::FAST_TANH);
INST(Mode::ROUND);
INST(Mode::ERF);
INST(Mode::ERFINV);
INST(Mode::ERFC);
INST(Mode::ERFCINV);
INST(Mode::H_SWISH);
INST(Mode::SILU);
INST(Mode::GELU);
#undef INST
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 138
- 0
dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp View File

@@ -0,0 +1,138 @@
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./opr_impl.h"
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_int32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_int32>(src2).begin();
auto dst_ptr = dst.ptr<dt_int32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1);
++i0;
++i1;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
switch (param[0].layout.dtype.enumv()) {
#define cb(t) \
case DTypeTrait<t>::enumv: \
return dispatch_fma3_iXxf32xf32xi8<DTypeTrait<t>::ctype>(param, dst);
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
default:
megdnn_throw("unsupported src dtype");
}
}

template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
elemwise_multi_type::Fma3iXxf32xf32xiYOp<ctype, dt_int8> op;
auto i0 = tensor_iter_valonly<ctype>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_int8>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = op(*i0, *i1, *i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

// vim: syntax=cpp.doxygen

+ 115
- 0
dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp View File

@@ -0,0 +1,115 @@
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./opr_impl.h"
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;

void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
switch (param[0].layout.dtype.enumv()) {
#define cb(t) \
case DTypeTrait<t>::enumv: \
return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype, dt_int8>( \
param, dst);
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
default:
megdnn_throw("unsupported src dtype");
}
}

template <typename ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto src0 = param[0];
auto src1 = param[1];
auto size = param.size;
auto work = [src0, src1, size, dst]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<ctype>(src0).begin();
auto iB = tensor_iter_valonly<dt_int8>(src1).begin();
auto pD = dst.ptr<dst_ctype>();
for (size_t i = 0; i < size; i++) {
*pD = elemwise_multi_type::round_shr_saturate<ctype, dst_ctype>(*iA, *iB);
++iA;
++iB;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate(
const ElemwiseOpParamN<6>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto src3 = param[3];
auto src4 = param[4];
auto src5 = param[5];
auto work = [size, src0, src1, src2, src3, src4, src5, dst]() {
auto i0 = tensor_iter_valonly<ctype>(src0).begin();
auto i1 = tensor_iter_valonly<ctype>(src1).begin();
auto i2 = tensor_iter_valonly<ctype>(src2).begin();
auto ioff = tensor_iter_valonly<dt_int8>(src3).begin();
auto imin = tensor_iter_valonly<dt_int8>(src4).begin();
auto imax = tensor_iter_valonly<dt_int8>(src5).begin();
auto dst_ptr = dst.ptr<dt_int8>();
for (size_t i = 0; i < size; ++i) {
auto res = elemwise_multi_type::round_shr_saturate<ctype, dt_int8>(
round_mulh_saturate<ctype>(*i0 + *i1, *i2), *ioff);
res = std::min(res, *imax);
res = std::max(res, *imin);
dst_ptr[i] = res;
++i0;
++i1;
++i2;
++ioff;
++imin;
++imax;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

//! int16 operands, int8 result: forwards to the typed fused kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int16>(param, dst);
}

//! int32 operands, int8 result: forwards to the typed fused kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int32>(param, dst);
}

//! Rounding right-shift with saturation to int16. Only Int32/Int16 sources
//! are dispatched here; other dtypes throw.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                        \
    case DTypeTrait<t>::enumv:                                                       \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype, dt_int16>( \
                param, dst);
        cb(::megdnn::dtype::Int32);
        cb(::megdnn::dtype::Int16);
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}

// vim: syntax=cpp.doxygen

dnn/src/naive/elemwise_multi_type/opr_impl.cpp → dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp View File

@@ -1,5 +1,5 @@
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl.cpp
* \file dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -18,218 +18,6 @@
using namespace megdnn;
using namespace naive;

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_int32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_int32>(src2).begin();
auto dst_ptr = dst.ptr<dt_int32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1);
++i0;
++i1;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

//! Fused multiply-add with int8 output where src0's integer dtype is only
//! known at runtime (src1/src2 are float32 per the typed kernel below).
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    // Map src0's runtime dtype enum to a static ctype and forward.
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                          \
    case DTypeTrait<t>::enumv:                                         \
        return dispatch_fma3_iXxf32xf32xi8<DTypeTrait<t>::ctype>(param, dst);
        MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}

template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
elemwise_multi_type::Fma3iXxf32xf32xiYOp<ctype, dt_int8> op;
auto i0 = tensor_iter_valonly<ctype>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_int8>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = op(*i0, *i1, *i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

//! Rounding right-shift with saturation to int8: src0 has a runtime integer
//! dtype, src1 holds per-element int8 shift amounts.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    // Resolve the runtime dtype enum of src0 to a static ctype and forward
    // to the statically-typed kernel; unhandled dtypes fall through to throw.
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                       \
    case DTypeTrait<t>::enumv:                                                      \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype, dt_int8>( \
                param, dst);
        MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}

template <typename ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xiX(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto src0 = param[0];
auto src1 = param[1];
auto size = param.size;
auto work = [src0, src1, size, dst]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<ctype>(src0).begin();
auto iB = tensor_iter_valonly<dt_int8>(src1).begin();
auto pD = dst.ptr<dst_ctype>();
for (size_t i = 0; i < size; i++) {
*pD = elemwise_multi_type::round_shr_saturate<ctype, dst_ctype>(*iA, *iB);
++iA;
++iB;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_round_shr_saturate(
const ElemwiseOpParamN<6>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto src3 = param[3];
auto src4 = param[4];
auto src5 = param[5];
auto work = [size, src0, src1, src2, src3, src4, src5, dst]() {
auto i0 = tensor_iter_valonly<ctype>(src0).begin();
auto i1 = tensor_iter_valonly<ctype>(src1).begin();
auto i2 = tensor_iter_valonly<ctype>(src2).begin();
auto ioff = tensor_iter_valonly<dt_int8>(src3).begin();
auto imin = tensor_iter_valonly<dt_int8>(src4).begin();
auto imax = tensor_iter_valonly<dt_int8>(src5).begin();
auto dst_ptr = dst.ptr<dt_int8>();
for (size_t i = 0; i < size; ++i) {
auto res = elemwise_multi_type::round_shr_saturate<ctype, dt_int8>(
round_mulh_saturate<ctype>(*i0 + *i1, *i2), *ioff);
res = std::min(res, *imax);
res = std::max(res, *imin);
dst_ptr[i] = res;
++i0;
++i1;
++i2;
++ioff;
++imin;
++imax;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

//! int16 operands, int8 result: forwards to the typed fused kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int16>(param, dst);
}

//! int32 operands, int8 result: forwards to the typed fused kernel.
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
    dispatch_fuse_add_rmulh_round_shr_saturate<dt_int32>(param, dst);
}

//! Rounding right-shift with saturation to int16. Only Int32/Int16 sources
//! are dispatched here; other dtypes throw.
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {
#define cb(t)                                                                        \
    case DTypeTrait<t>::enumv:                                                       \
        return dispatch_round_shr_saturate_iXxi8xiX<DTypeTrait<t>::ctype, dt_int16>( \
                param, dst);
        cb(::megdnn::dtype::Int32);
        cb(::megdnn::dtype::Int16);
#undef cb
        default:
            megdnn_throw("unsupported src dtype");
    }
}

template <typename KernImpl, typename src_ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_add_qint_op(
const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {

Loading…
Cancel
Save