
algos.cpp 9.6 kB

/**
 * \file dnn/src/aarch64/conv_bias/quint8/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/aarch64/conv_bias/quint8/algos.h"
#include "src/aarch64/conv_bias/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "midout.h"

MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm)
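//! midout tags each kernel instantiation below so that deployment builds can
//! strip kernels that are never exercised, keeping binary size down.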
using namespace megdnn;
using namespace aarch64;
using megdnn::arm_common::HSwishOp;
using megdnn::arm_common::ReluOp;
using megdnn::arm_common::TypeCvtOp;

/* ===================== matrix mul algo ===================== */
bool ConvBiasImpl::AlgoQU8MatrixMul::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    auto&& fm = param.filter_meta;
    return param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
           param.dst_type.enumv() == DTypeEnum::Quantized8Asymm &&
           fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
           fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
           //! As a postprocess, the bias is not read contiguously, which
           //! hurts performance, so we do not handle it in the fused kernel
           param.bias_mode != BiasMode::BIAS &&
           //! This algo only supports a single thread
           param.nr_threads == 1_z;
}
WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
        const NCBKernSizeParam& param) {
    UNPACK_CONV_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
    // temp space to store padding-free src (with 16 extra int8)
    // temp space to store unrolled matrix (with 16 extra int8)
    // workspace for matrix mul opr
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = part1 = 0;
    } else {
        part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t);
        part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t);
    }
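    // (the 16 extra elements appear to leave slack so vectorized kernels can
    // safely read slightly past the logical end of each buffer)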
    {
        size_t M = OC;
        size_t K = IC * FH * FW;
        size_t N = OH * OW;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(                                           \
        _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline,      \
        _nonline_midout_enum)                                              \
    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                  \
            M, N, K, param.filter_type, param.src_type, param.dst_type);   \
    part2 = megdnn::matmul::GemmInterleaved<                               \
                    matmul::gemm_##_gemm##_##_bias##_##_nonline>(          \
                    M, N, K, false, false, strategy)                       \
                    .get_workspace_size();
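        //! DISPATCH_GEMM_BIAS (from src/fallback/conv_bias/common.h) selects
        //! the bias/nonlinearity variant requested by param and invokes
        //! DISPATCH_GEMM_STRATEGY for it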
        if (cpuinfo_has_arm_neon_dot()) {
            DISPATCH_GEMM_BIAS(u8_8x8_dot, 1);
        } else {
            DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0);
        }
#else
#define DISPATCH_GEMM_STRATEGY(                                            \
        _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline,       \
        _nonline_midout_enum)                                               \
    MIDOUT_BEGIN(                                                           \
            megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum,     \
            _bias_midout_enum, _nonline_midout_enum) {                      \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(               \
                M, N, K, param.filter_type, param.src_type, param.dst_type);\
        part2 = megdnn::matmul::GemmInterleaved<                            \
                        matmul::gemm_##_gemm##_##_bias##_##_nonline>(       \
                        M, N, K, false, false, strategy)                    \
                        .get_workspace_size();                              \
    }                                                                       \
    MIDOUT_END()
        DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
#endif
#undef DISPATCH_GEMM_STRATEGY
    }
    return {nullptr, {part0, part1, part2}};
}
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(
        const NCBKernParam& param, const NCBKernIndex& ncb_index) {
    auto is_xcorr = !param.filter_meta.should_flip;
    UNPACK_CONV_NCB_KERN_SIZES(param);
    auto bundle = get_bundle(param);
    bundle.set(param.workspace_ptr);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    size_t group_id = ncb_index.ndrange_id[0];
    uint8_t src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
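    // padding is filled with the quantized zero point, so padded pixels
    // dequantize to exactly 0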
    // workspace layout: bundle[0] = padded src2, bundle[1] = im2col matrix B,
    // bundle[2] = matmul workspace
    for (size_t n = 0; n < N; ++n) {
        uint8_t* src = const_cast<uint8_t*>(param.src<uint8_t>(n, group_id));
        uint8_t* filter = const_cast<uint8_t*>(param.filter<uint8_t>(group_id));
        uint8_t* dst = static_cast<uint8_t*>(param.dst<uint8_t>(n, group_id));
        int32_t* bias = const_cast<int32_t*>(param.bias<int32_t>(n, group_id));
        uint8_t *B, *src2;
        if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
            // special case: 1x1
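            // with a 1x1 kernel, unit stride and no padding, im2col is the
            // identity mapping: src is already the (IC) x (OH*OW) matrix B,
            // so no copy is needed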
            B = const_cast<uint8_t*>(src);
        } else {
            src2 = static_cast<uint8_t*>(bundle.get(0));
            // copy src to src2, filling the border with src_zp
            uint8_t* src2_ptr = src2;
            const uint8_t* src_ptr = src;
            rep(ic, IC) {
                if (PH != 0) {
                    std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
                rep(ih, IH) {
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = src_zp; }
                    std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW);
                    src2_ptr += IW;
                    src_ptr += IW;
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = src_zp; }
                }
                if (PH != 0) {
                    std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
            }
            B = static_cast<uint8_t*>(bundle.get(1));
            if (SH == 1 && SW == 1) {
                if (is_xcorr)
                    img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
                else
                    img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
            } else {
                if (is_xcorr)
                    img2col_stride<true>(
                            src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
                else
                    img2col_stride<false>(
                            src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
            }
        }
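        // B now holds the unrolled input: an (IC*FH*FW) x (OH*OW) matrix, so
        // the convolution reduces to filter (OC x IC*FH*FW) times B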
        {
            Workspace workspace(
                    static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2));
            size_t M = OC;
            size_t K = IC * FH * FW;
            size_t N = OH * OW;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(                                              \
        _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline,        \
        _nonline_midout_enum)                                                \
    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                    \
            M, N, K, param.filter_type, param.src_type, param.dst_type);     \
    megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
            gemm_interleaved(M, N, K, false, false, strategy);               \
    gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);
            if (cpuinfo_has_arm_neon_dot()) {
                DISPATCH_GEMM_BIAS(u8_8x8_dot, 1)
            } else {
                DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
            }
#else
#define DISPATCH_GEMM_STRATEGY(                                              \
        _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline,        \
        _nonline_midout_enum)                                                \
    MIDOUT_BEGIN(                                                            \
            megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum,      \
            _bias_midout_enum, _nonline_midout_enum) {                       \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
                gemm_interleaved(M, N, K, false, false, strategy);           \
        gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \
    }                                                                        \
    MIDOUT_END()
            DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
#endif
#undef DISPATCH_GEMM_STRATEGY
        }
    }
}

// vim: syntax=cpp.doxygen
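For readers unfamiliar with the im2col trick the kernel above relies on, here is a minimal, self-contained sketch. It is not MegEngine code: the names are hypothetical, and it ignores quantization zero points, padding and strides. It only shows how, once the input is unrolled, the convolution becomes a single M x K x N matrix multiply, which is the role GemmInterleaved::execute plays in kimpl.

#include <cstdint>
#include <cstdio>
#include <vector>

// Unroll an IC x IH x IW input into B with shape (IC*FH*FW) x (OH*OW),
// stride 1, no padding: column (oh, ow) collects the receptive field of
// output pixel (oh, ow).
static void im2col_naive(
        const uint8_t* src, uint8_t* B, size_t IC, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW) {
    for (size_t ic = 0; ic < IC; ++ic)
        for (size_t fh = 0; fh < FH; ++fh)
            for (size_t fw = 0; fw < FW; ++fw) {
                size_t row = (ic * FH + fh) * FW + fw;
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow)
                        B[row * OH * OW + oh * OW + ow] =
                                src[(ic * IH + oh + fh) * IW + ow + fw];
            }
}

int main() {
    // toy sizes: 2 input channels, 4x4 image, 3x3 filter, 1 output channel
    const size_t IC = 2, IH = 4, IW = 4, FH = 3, FW = 3, OC = 1;
    const size_t OH = IH - FH + 1, OW = IW - FW + 1;  // stride 1, no padding
    const size_t M = OC, K = IC * FH * FW, N = OH * OW;

    std::vector<uint8_t> src(IC * IH * IW);
    std::vector<uint8_t> filter(OC * K, 1);  // all-ones filter for simplicity
    for (size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<uint8_t>(i % 7);

    std::vector<uint8_t> B(K * N);
    im2col_naive(src.data(), B.data(), IC, IH, IW, FH, FW, OH, OW);

    // plain GEMM with int32 accumulation: dst = filter (M x K) * B (K x N)
    std::vector<int32_t> dst(M * N, 0);
    for (size_t m = 0; m < M; ++m)
        for (size_t k = 0; k < K; ++k)
            for (size_t n = 0; n < N; ++n)
                dst[m * N + n] +=
                        int32_t(filter[m * K + k]) * int32_t(B[k * N + n]);

    for (size_t n = 0; n < N; ++n)
        std::printf("dst[%zu] = %d\n", n, dst[n]);
    return 0;
}

In the real kernel, B lives in bundle[1] (or aliases src in the 1x1 fast path), the GEMM is the packed, NEON-accelerated GemmInterleaved, and the bias add plus requantization to Quantized8Asymm happen inside the strategy selected by DISPATCH_GEMM_BIAS.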

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.