
algos.cpp

/**
 * \file dnn/src/armv7/conv_bias/int8/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/armv7/conv_bias/int8/algos.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/armv7/conv_bias/int8/strategy.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"

#include "midout.h"

MIDOUT_DECL(megdnn_armv7_conv_bias_int8)

using namespace megdnn;
using namespace armv7;
/* ===================== matrix mul algo ===================== */
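// This algo lowers the convolution to a single quantized GEMM via im2col:
// the filter is viewed as an M x K matrix (M = OC, K = IC * FH * FW), the
// input is unrolled into a K x N matrix (N = OH * OW), and the dispatched
// strategy fuses bias addition and the activation into the output pass.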
bool ConvBiasImpl::AlgoS8MatrixMul::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    auto&& fm = param.filter_meta;
    return param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
           param.dst_type.enumv() == DTypeEnum::QuantizedS8 &&
           fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
           fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
           //! In postprocess, a full BIAS tensor is not read contiguously,
           //! which hurts performance, so we do not handle it in the fused
           //! kernel.
           param.bias_mode != BiasMode::BIAS &&
           //! This algo only supports a single thread.
           param.nr_threads == 1_z;
}
WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
        const NCBKernSizeParam& param) {
    UNPACK_CONV_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    bool can_matrix_mul_direct =
            (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
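    // For a 1x1, stride-1, unpadded convolution, im2col is the identity:
    // src is already the K x N operand (K = IC, N = OH * OW), so the
    // padding and unrolling buffers below are not needed (see kimpl).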
    // temp space to store padding-free src (with 16 extra int8)
    // temp space to store unrolled matrix (with 16 extra int8)
    // workspace for matrix mul opr
    size_t part0, part1, part2;
    if (can_matrix_mul_direct) {
        part0 = part1 = 0;
    } else {
        part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t);
        part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t);
    }
    {
        size_t M = OC;
        size_t K = IC * FH * FW;
        size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,              \
                               _bias_midout_enum, _nonline,                  \
                               _nonline_midout_enum)                         \
    MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 0, _gemm_midout_enum,          \
                 _bias_midout_enum, _nonline_midout_enum) {                  \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        part2 = megdnn::matmul::GemmInterleaved<                             \
                        matmul::gemm_##_gemm##_##_bias##_##_nonline>(        \
                        M, N, K, false, false, strategy)                     \
                        .get_workspace_size();                               \
    }                                                                        \
    MIDOUT_END()
        DISPATCH_GEMM_BIAS(s8_4x2, 0)
#undef DISPATCH_GEMM_STRATEGY
    }
    return {nullptr, {part0, part1, part2}};
}
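// Kernel flow per batch image: (1) unless the 1x1 fast path applies,
// zero-pad src into bundle part 0 and unroll it with img2col into part 1;
// (2) run the interleaved GEMM using bundle part 2 as scratch, letting the
// dispatched strategy apply bias and activation while writing dst.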
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
                                          const NCBKernIndex& ncb_index) {
    auto is_xcorr = !param.filter_meta.should_flip;
    UNPACK_CONV_NCB_KERN_SIZES(param);
    auto bundle = get_bundle(param);
    bundle.set(param.workspace_ptr);
    auto IH2 = IH + 2 * PH;
    auto IW2 = IW + 2 * PW;
    size_t group_id = ncb_index.ndrange_id[0];
    // workspace layout: [padded src (part 0) | unrolled matrix (part 1)]
    for (size_t n = 0; n < N; ++n) {
        dt_int8* src = const_cast<dt_int8*>(param.src<dt_int8>(n, group_id));
        dt_int8* filter = const_cast<dt_int8*>(param.filter<dt_int8>(group_id));
        dt_int8* dst = static_cast<dt_int8*>(param.dst<dt_int8>(n, group_id));
        dt_int32* bias = const_cast<dt_int32*>(param.bias<dt_int32>(n, group_id));
        dt_int8 *B, *src2;
        if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
            // special case: 1x1
            B = const_cast<dt_int8*>(src);
        } else {
            src2 = static_cast<dt_int8*>(bundle.get(0));
            // copy src into src2, adding the zero padding
            dt_int8* src2_ptr = src2;
            const dt_int8* src_ptr = src;
            rep(ic, IC) {
                if (PH != 0) {
                    std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
                rep(ih, IH) {
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = 0; }
                    std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW);
                    src2_ptr += IW;
                    src_ptr += IW;
                    if (PW != 0)
                        rep(pw, PW) { *(src2_ptr++) = 0; }
                }
                if (PH != 0) {
                    std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
                    src2_ptr += PH * IW2;
                }
            }
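            // src2 now holds the zero-padded image (IC x IH2 x IW2);
            // img2col flattens each output position's FH x FW receptive
            // field into one column of B, producing the K x N operand
            // (K = IC * FH * FW, N = OH * OW). The *_stride variant
            // handles SH/SW > 1.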
            B = static_cast<dt_int8*>(bundle.get(1));
            if (SH == 1 && SW == 1) {
                if (is_xcorr)
                    img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
                else
                    img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
            } else {
                if (is_xcorr)
                    img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2,
                                         FH, FW, SH, SW);
                else
                    img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2,
                                          FH, FW, SH, SW);
            }
        }
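        // The dispatched strategy runs the M x K by K x N int8 GEMM and
        // applies bias and the activation during the output pass, writing
        // quantized int8 results to dst.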
        {
            Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
                                bundle.get_size(2));
            size_t M = OC;
            size_t K = IC * FH * FW;
            size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,              \
                               _bias_midout_enum, _nonline,                  \
                               _nonline_midout_enum)                         \
    MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 1, _gemm_midout_enum,          \
                 _bias_midout_enum, _nonline_midout_enum) {                  \
        matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                \
                M, N, K, param.filter_type, param.src_type, param.dst_type); \
        megdnn::matmul::GemmInterleaved<                                     \
                matmul::gemm_##_gemm##_##_bias##_##_nonline>                 \
                gemm_interleaved(M, N, K, false, false, strategy);           \
        gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
                                 bias);                                      \
    }                                                                        \
    MIDOUT_END()
            DISPATCH_GEMM_BIAS(s8_4x2, 0)
#undef DISPATCH_GEMM_STRATEGY
        }
    }
}
// vim: syntax=cpp.doxygen
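For reference, here is a minimal standalone sketch of the im2col idea the
kernel above relies on, covering the stride-1 cross-correlation case on an
already zero-padded NCHW image. The function name and signature are
illustrative only and do not match MegEngine's img2col_helper API.

#include <cstddef>
#include <cstdint>

// Illustrative im2col: expand a zero-padded image (IC x IH2 x IW2) into a
// (IC * FH * FW) x (OH * OW) matrix so the convolution becomes one GEMM.
// Assumes stride 1 and cross-correlation; OH = IH2 - FH + 1, OW = IW2 - FW + 1.
void im2col_sketch(const int8_t* src, int8_t* dst, size_t IC, size_t IH2,
                   size_t IW2, size_t FH, size_t FW, size_t OH, size_t OW) {
    for (size_t ic = 0; ic < IC; ++ic)
        for (size_t fh = 0; fh < FH; ++fh)
            for (size_t fw = 0; fw < FW; ++fw)
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow) {
                        size_t row = (ic * FH + fh) * FW + fw;  // K index
                        size_t col = oh * OW + ow;              // N index
                        dst[row * OH * OW + col] =
                                src[(ic * IH2 + oh + fh) * IW2 + ow + fw];
                    }
}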
