
opr_impl.cpp 6.8 kB

/**
 * \file dnn/src/x86/conv_bias/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/x86/conv_bias/opr_impl.h"
#include <algorithm>
#include <memory>
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/x86/conv_bias/f32/algos.h"
#include "src/x86/conv_bias/int8/algo_usable_preferred.h"
#include "src/x86/conv_bias/int8/algos.h"
#include "src/x86/matrix_mul/opr_impl.h"

using namespace megdnn;
using namespace x86;

namespace {
//! the address of this storage is used as a unique type tag shared by all
//! x86 conv-bias algorithms
uint8_t x86_algo_type_storage;
void* x86_algo_type = &x86_algo_type_storage;
}  // anonymous namespace

#if MEGDNN_X86_WITH_MKL_DNN
void* ConvBiasImpl::AlgoMkldnnQint8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoMkldnnMatmulQint8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoMkldnnConv::type() const {
    return x86_algo_type;
}
#endif

void* ConvBiasImpl::AlgoDirect::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoDirectStride2::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoMatrixMul::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoDirectAvx2Stride1Int8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoFP32WinogradF63_8x8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoFP32WinogradF23_8x8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoAVX2DirectConvStride2::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const {
    return x86_algo_type;
}
void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const {
    return x86_algo_type;
}

class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoDirect stride1_direct_large_group{true};
    AlgoDirect stride1_direct_small_group{false};
    AlgoDirectStride2 stride2_direct_large_group{true};
    AlgoDirectStride2 stride2_direct_small_group{false};
    AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8;
    AlgoAVX2DirectConvStride2 avx2_stride2_direct;
    AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8;
    AlgoChanWiseAvx2Stride2Qint8 avx2_stride2_chanwsie_qint8;
    AlgoMatrixMul matmul;
#if MEGDNN_X86_WITH_MKL_DNN
    AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8;
    //! Because the mkldnnconv need handle
    AlgoMkldnnQint8 mkldnn_qint8;
    AlgoMkldnnConv mkldnn_conv_fp32;
#endif
    SmallVector<std::unique_ptr<AlgoBase>> refhold;

public:
    AlgoPack() {
        //! FIXME: prefer the mkldnn algos on VNNI devices; for now the
        //! mkldnn algos have a preference issue with NCHW->NHWC->NCHW
#if MEGDNN_X86_WITH_MKL_DNN
        //! Create the mkldnn algos
        all_algos.emplace_back(&mkldnn_conv_fp32);
        all_algos.emplace_back(&mkldnn_matmul_qint8);
        all_algos.emplace_back(&mkldnn_qint8);
#endif
        all_algos.emplace_back(&stride1_direct_large_group);
        all_algos.emplace_back(&stride1_direct_small_group);
        all_algos.emplace_back(&stride2_direct_large_group);
        all_algos.emplace_back(&stride2_direct_small_group);
        all_algos.emplace_back(&avx2_stride1_direct_int8);
        all_algos.emplace_back(&avx2_stride2_direct);
        all_algos.emplace_back(&avx2_stride1_chanwsie_qint8);
        all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
        all_algos.emplace_back(&matmul);

        //! wrap each x86 matmul algorithm into Winograd F(6,3) 8x8 and
        //! F(2,3) 8x8 conv-bias algorithms, one instance per tile size
        static CpuOprDelegationStorage<> storage;
        auto matmul_opr = storage.get<MatrixMul>();
        auto&& matmul_algos =
                static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack();
        for (auto&& algo : matmul_algos) {
            if (algo->type() == nullptr)
                continue;
            for (uint32_t tile_size : {8, 16, 24}) {
                refhold.emplace_back(new AlgoFP32WinogradF63_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF23_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
            }
        }
    }

    SmallVector<AlgoBase*> all_algos;
    SmallVector<AlgoBase*> winograd_algos;
};

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    auto&& algos = fallback::ConvBiasImpl::algo_pack();
    algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(),
                 sl_algo_pack.all_algos.end());
    algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
                 sl_algo_pack.winograd_algos.end());
    return std::move(algos);
}

void ConvBiasImpl::get_rectified_img_size(size_t IH, size_t IW, size_t FH,
                                          size_t FW, size_t OH, size_t OW,
                                          size_t PH, size_t PW, size_t& IH2,
                                          size_t& IW2, size_t& OH2,
                                          size_t& OW2) {
    //! round the output width up to a multiple of 8 and enlarge the input
    //! size so it still covers the rounded output plus filter and padding
    OW2 = (OW + 7) >> 3 << 3;
    OH2 = OH;
    IH2 = std::max(IH, OH2 + FH - 1 + 2 * PH);
    IW2 = std::max(IW, OW2 + FW - 1 + 2 * PW);
}

const char* ConvBiasImpl::get_algorithm_set_name() const {
    // x86 version 0
    return "X0";
}

bool ConvBiasImpl::is_matmul_quantized_prefer(
        const ConvBiasImpl::NCBKernSizeParam& param) const {
    bool conv_direct_chanwise_mkldnn_usable = true;
    if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
        param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
        conv_direct_chanwise_mkldnn_usable =
                chanwise_avx2_stride1_qint8_usable_preferred(param) ||
                chanwise_avx2_stride2_qint8_usable_preferred(param) ||
                direct_avx2_stride1_int8_usable_preferred(param) ||
                direct_avx2_stride2_int8_usable_preferred(param);
#if MEGDNN_X86_WITH_MKL_DNN
        conv_direct_chanwise_mkldnn_usable =
                conv_direct_chanwise_mkldnn_usable ||
                mkldnn_qint8_usable_preferred(param) ||
                mkldnn_matmul_qint8_usable_preferred(param);
#endif
    }
    return !conv_direct_chanwise_mkldnn_usable ||
           (is_supported(SIMDType::VNNI) &&
            !chanwise_avx2_stride1_qint8_usable_preferred(param) &&
            !chanwise_avx2_stride2_qint8_usable_preferred(param));
}

// vim: syntax=cpp.doxygen

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. If you want to run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
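A quick way to confirm that the unified package can actually see a GPU is to query MegEngine from Python. The following is a minimal sketch, assuming the megengine package is installed; the helper names used here (is_cuda_available, get_device_count) are taken from MegEngine's public Python API and may differ between versions.

    # check whether the unified MegEngine package can use a GPU
    # assumes `pip install megengine`; API names may vary by version
    import megengine as mge

    print("MegEngine version:", mge.__version__)

    if mge.is_cuda_available():
        # a GPU and a working driver are visible to MegEngine
        print("CUDA devices found:", mge.get_device_count("gpu"))
    else:
        print("No usable GPU found; computation will run on the CPU.")

If the check reports no usable GPU even though the hardware is present, the driver installation is the first thing to verify, since the CUDA runtime itself ships with the package.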