
opr_impl.cpp 8.0 kB

/**
 * \file dnn/src/x86/conv_bias/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/x86/conv_bias/opr_impl.h"
#include <algorithm>
#include <memory>
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/x86/conv_bias/f32/algos.h"
#include "src/x86/conv_bias/int8/algo_usable_preferred.h"
#include "src/x86/conv_bias/int8/algos.h"
#include "src/x86/matrix_mul/opr_impl.h"

using namespace megdnn;
using namespace x86;

namespace {
bool is_fallback_or_naive(const detail::Algorithm* algo) {
    return algo->handle_type() == Handle::HandleType::NAIVE ||
           algo->handle_type() == Handle::HandleType::FALLBACK;
}
}  // anonymous namespace
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoDirect stride1_direct;
    AlgoDirectStride2 stride2_direct;
    AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8;
    AlgoAVX2DirectConvStride2 avx2_stride2_direct;
    AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwise_qint8;
    AlgoChanWiseAvx2Stride2Qint8 avx2_stride2_chanwise_qint8;
#if MEGDNN_X86_WITH_MKL_DNN
    AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8;
    //! Because the mkl-dnn conv needs the handle
    AlgoMkldnnQint8 mkldnn_qint8;
    AlgoMkldnnConv mkldnn_conv_fp32;
#endif
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_no_winograd_algo;
    SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos;
    fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
public:
    AlgoPack() {
        //! FIXME: prefer the mkl-dnn algos on VNNI devices, but for now the
        //! mkl-dnn algos have a preference issue with NCHW->NHWC->NCHW
#if MEGDNN_X86_WITH_MKL_DNN
        //! Create the mkl-dnn algos
        m_all_no_winograd_algo.emplace_back(&mkldnn_conv_fp32);
        m_all_no_winograd_algo.emplace_back(&mkldnn_matmul_qint8);
        m_all_no_winograd_algo.emplace_back(&mkldnn_qint8);
#endif
        m_all_no_winograd_algo.emplace_back(&stride1_direct);
        m_all_no_winograd_algo.emplace_back(&stride2_direct);
        m_all_no_winograd_algo.emplace_back(&avx2_stride1_chanwise_qint8);
        m_all_no_winograd_algo.emplace_back(&avx2_stride2_chanwise_qint8);
        m_all_no_winograd_algo.emplace_back(&avx2_stride1_direct_int8);
        m_all_no_winograd_algo.emplace_back(&avx2_stride2_direct);

        static CpuOprDelegationStorage<> storage;
        auto matmul_opr = storage.get<MatrixMul>();
        auto&& matmul_algos =
                static_cast<MatrixMulImpl*>(matmul_opr)->get_all_packed_algo();
        for (auto&& algo : matmul_algos) {
            if (is_fallback_or_naive(algo))
                continue;
            for (uint32_t tile_size : {8, 16, 24}) {
                refhold.emplace_back(new AlgoFP32WinogradF63_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF23_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                m_winograd_algos.emplace_back(refhold.back().get());
            }
        }
        for (auto&& algo : m_all_no_winograd_algo) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
        for (auto&& algo : m_winograd_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_no_winograd_algo()
            const {
        return m_all_no_winograd_algo;
    }
    const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() const {
        return m_winograd_algos;
    }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

fallback::ConvBiasImpl::AlgoBase* ConvBiasImpl::get_algo_from_desc(
        const AlgorithmDesc& desc) {
    megdnn_assert(algo_pack().all_algos_map().find(desc) !=
                  algo_pack().all_algos_map().end());
    return algo_pack().all_algos_map().at(desc);
}

SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
    auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo();
    algos.insert(algos.begin(), algo_pack().all_no_winograd_algo().begin(),
                 algo_pack().all_no_winograd_algo().end());
    algos.insert(algos.end(), algo_pack().winograd_algos().begin(),
                 algo_pack().winograd_algos().end());
    return std::move(algos);
}

void ConvBiasImpl::get_rectified_img_size(size_t IH, size_t IW, size_t FH,
                                          size_t FW, size_t OH, size_t OW,
                                          size_t PH, size_t PW, size_t& IH2,
                                          size_t& IW2, size_t& OH2,
                                          size_t& OW2) {
    //! round the output width up to a multiple of 8, then enlarge the input
    //! region so it covers the padded output
    OW2 = (OW + 7) >> 3 << 3;
    OH2 = OH;
    IH2 = std::max(IH, OH2 + FH - 1 + 2 * PH);
    IW2 = std::max(IW, OW2 + FW - 1 + 2 * PW);
}
const char* ConvBiasImpl::get_algorithm_set_name() const {
    // x86 version 0
    return "X0";
}

bool ConvBiasImpl::is_matmul_quantized_prefer(
        const ConvBiasImpl::NCBKernSizeParam& param) const {
    bool conv_direct_chanwise_mkldnn_usable = true;
    if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
        param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
        conv_direct_chanwise_mkldnn_usable =
                chanwise_avx2_stride1_qint8_usable_preferred(param) ||
                chanwise_avx2_stride2_qint8_usable_preferred(param) ||
                direct_avx2_stride1_int8_usable_preferred(param) ||
                direct_avx2_stride2_int8_usable_preferred(param);
#if MEGDNN_X86_WITH_MKL_DNN
        conv_direct_chanwise_mkldnn_usable =
                conv_direct_chanwise_mkldnn_usable ||
                mkldnn_qint8_usable_preferred(param) ||
                mkldnn_matmul_qint8_usable_preferred(param);
#endif
    }
    return !conv_direct_chanwise_mkldnn_usable ||
           (is_supported(SIMDType::VNNI) &&
            !chanwise_avx2_stride1_qint8_usable_preferred(param) &&
            !chanwise_avx2_stride2_qint8_usable_preferred(param));
}
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    auto IC = param.filter_meta.icpg;
    auto OC = param.filter_meta.ocpg;
    auto FH = param.filter_meta.spatial[0];
    auto FW = param.filter_meta.spatial[1];
    //! TODO: winograd is currently only supported in fast-run
    //! nchw88 uses mkl-dnn, whose algo is direct
    if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
    }
    //! im2col + matmul
    bool im2col_prefer = (IC >= 32 || OC >= 32);
    //! quantized algos use matmul when the direct algos are unusable
    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
        im2col_prefer = is_matmul_quantized_prefer(param);
    }
    //! conv1x1
    im2col_prefer |= (FH == 1 && FW == 1);
    //! x86 8x8x16 is not optimized, so it falls back to im2col + matmul
    if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
        im2col_prefer = true;
    }
    if (im2col_prefer) {
        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
                AlgoCategory::NAIVE};
    } else {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
                AlgoCategory::NAIVE};
    }
}

// vim: syntax=cpp.doxygen
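
A note on the registry lifetime: ConvBiasImpl::algo_pack() returns a function-local static AlgoPack, so the registry is built once, on first use, and C++11 guarantees that initialization is thread-safe. A minimal sketch of the same pattern, using a hypothetical Registry type that stands in for AlgoPack:

    #include <string>
    #include <vector>

    // Hypothetical stand-in for AlgoPack: a registry that is
    // expensive to construct.
    struct Registry {
        std::vector<std::string> names{"direct", "im2col", "winograd"};
    };

    // Same shape as ConvBiasImpl::algo_pack(): the local static is
    // constructed exactly once, on the first call, and C++11 magic
    // statics make that initialization thread-safe.
    const Registry& registry() {
        static Registry instance;
        return instance;
    }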
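
Similarly, the shift expression in get_rectified_img_size, (OW + 7) >> 3 << 3, rounds the output width up to the next multiple of 8 by adding 7 and then clearing the low three bits. A standalone sketch of the idiom (the helper name round_up_to_8 is ours, for illustration only):

    #include <cassert>
    #include <cstddef>

    // Round x up to the next multiple of 8: add 7, then shift away
    // (i.e. zero out) the low three bits.
    size_t round_up_to_8(size_t x) {
        return (x + 7) >> 3 << 3;
    }

    int main() {
        assert(round_up_to_8(0) == 0);    // already aligned
        assert(round_up_to_8(1) == 8);
        assert(round_up_to_8(8) == 8);    // multiples are unchanged
        assert(round_up_to_8(13) == 16);
        return 0;
    }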

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and the driver is properly installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.