You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /**
  2. * \file dnn/src/arm_common/conv_bias/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/arm_common/conv_bias/int8/algos.h"
  13. #include "src/arm_common/conv_bias/int8x8x16/algos.h"
  14. #include "src/arm_common/conv_bias/quint8/algos.h"
  15. #include "src/arm_common/conv_bias/opr_impl.h"
  16. #include "src/common/metahelper.h"
  17. #include "src/common/utils.h"
  18. #include "src/naive/handle.h"
  19. #include "src/arm_common/convolution/opr_impl.h"
  20. #include "src/arm_common/matrix_mul/opr_impl.h"
  21. #include "src/common/opr_delegate.h"
  22. #include "include/megdnn/oprs/nn.h"
  23. #include "src/arm_common/conv_bias/f16/algos.h"
  24. #include "src/arm_common/conv_bias/fp32/algos.h"
  25. #include "src/arm_common/conv_bias/int8/stride1.h"
  26. #include "src/arm_common/conv_bias/int8/stride2.h"
  27. #include "src/arm_common/conv_bias/quint8/stride1.h"
  28. #include "src/arm_common/conv_bias/quint8/stride2.h"
  29. #include "src/arm_common/convolution/opr_impl.h"
using namespace megdnn;
using namespace arm_common;

namespace {
// Dummy byte whose *address* serves as a unique, process-wide tag identifying
// algorithms that belong to the arm_common backend (exposed through
// ConvBiasImpl::sm_arm_common_algo_type below). The stored value is never read.
uint8_t arm_common_algo_type_storage;
}  // anonymous namespace
  35. class ConvBiasImpl::AlgoPack : NonCopyableObj {
  36. AlgoQU8DirectStride2 qu8_direct_stride2;
  37. AlgoQU8DirectStride1 qu8_direct_stride1;
  38. AlgoS8DirectStride2 s8_direct_stride2;
  39. AlgoS8DirectNCHW44 s8_direct_nchw44;
  40. AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
  41. AlgoS8DirectStride1 s8_direct_stride1;
  42. AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
  43. AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
  44. AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
  45. #if __ARM_FEATURE_DOTPROD
  46. AlgoDotS8DirectStride1 ds8_direct_stride1;
  47. AlgoDotS8DirectStride2 ds8_direct_stride2;
  48. AlgoDotU8DirectStride1 du8_direct_stride1;
  49. AlgoDotU8DirectStride2 du8_direct_stride2;
  50. AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
  51. AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
  52. #endif
  53. AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
  54. AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
  55. AlgoF32DirectNCHW44 f32_direct_nchw44;
  56. AlgoF32Direct f32_direct;
  57. AlgoF32DirectStride2 f32_direct_stride2;
  58. AlgoF32DirectStride1 f32_direct_stride1;
  59. AlgoI8x8x16Direct i8x8x16_direct;
  60. AlgoI8x8x16Stride2 i8x8x16_stride2;
  61. AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
  62. AlgoI8x8x16DirectNCHWNCHW44 i8x8x16_nchw_nchw44;
  63. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  64. AlgoF16Direct f16_direct;
  65. AlgoF16DirectStride1 f16_direct_stride1;
  66. #endif
  67. SmallVector<std::unique_ptr<AlgoBase>> refhold;
  68. public:
  69. AlgoPack() {
  70. #if __ARM_FEATURE_DOTPROD
  71. direct_algos.emplace_back(&ds8_direct_stride1);
  72. direct_algos.emplace_back(&ds8_direct_stride2);
  73. direct_algos.emplace_back(&du8_direct_stride1);
  74. direct_algos.emplace_back(&du8_direct_stride2);
  75. direct_algos.emplace_back(&ds8_direct_nchw44);
  76. direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
  77. #endif
  78. direct_algos.emplace_back(&qu8_direct_stride2);
  79. direct_algos.emplace_back(&qu8_direct_stride1);
  80. direct_algos.emplace_back(&s8_direct_stride2);
  81. direct_algos.emplace_back(&s8_direct_nchw44);
  82. direct_algos.emplace_back(&s8_direct_nchw_nchw44);
  83. direct_algos.emplace_back(&s8_direct_stride1);
  84. direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
  85. direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
  86. direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
  87. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  88. direct_algos.emplace_back(&f16_direct_stride1);
  89. direct_algos.emplace_back(&f16_direct);
  90. #endif
  91. direct_algos.emplace_back(&i8x8x16_direct);
  92. direct_algos.emplace_back(&i8x8x16_stride2_filter2);
  93. direct_algos.emplace_back(&i8x8x16_stride2);
  94. direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
  95. direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
  96. direct_algos.emplace_back(&f32_chanel_wise_nchw44);
  97. direct_algos.emplace_back(&f32_direct_nchw44);
  98. direct_algos.emplace_back(&f32_direct_stride1);
  99. direct_algos.emplace_back(&f32_direct_stride2);
  100. direct_algos.emplace_back(&f32_direct);
  101. static CpuOprDelegationStorage<2> storage;
  102. auto matmul_opr = storage.get<MatrixMul, 0>();
  103. auto&& matmul_algos =
  104. static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
  105. ->algo_pack();
  106. for (auto&& algo : matmul_algos) {
  107. if (algo->type() == nullptr)
  108. continue;
  109. for (uint32_t tile_size : {16, 8, 24, 32}) {
  110. refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
  111. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  112. tile_size));
  113. winograd_algos.emplace_back(refhold.back().get());
  114. refhold.emplace_back(new AlgoFP32WinogradF63(
  115. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  116. tile_size));
  117. winograd_algos.emplace_back(refhold.back().get());
  118. refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
  119. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  120. tile_size));
  121. winograd_algos.emplace_back(refhold.back().get());
  122. refhold.emplace_back(new AlgoFP32WinogradF54(
  123. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  124. tile_size));
  125. winograd_algos.emplace_back(refhold.back().get());
  126. refhold.emplace_back(new AlgoFP32WinogradF45(
  127. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  128. tile_size));
  129. winograd_algos.emplace_back(refhold.back().get());
  130. refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
  131. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  132. tile_size));
  133. winograd_algos.emplace_back(refhold.back().get());
  134. refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
  135. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  136. tile_size));
  137. winograd_algos.emplace_back(refhold.back().get());
  138. //! uncomment this when low precision mode is done
  139. #if 0
  140. refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
  141. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  142. tile_size));
  143. winograd_algos.emplace_back(refhold.back().get());
  144. #endif
  145. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  146. refhold.emplace_back(new AlgoFP16WinogradF23(
  147. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  148. tile_size));
  149. winograd_algos.emplace_back(refhold.back().get());
  150. refhold.emplace_back(new AlgoFP16WinogradF45(
  151. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  152. tile_size));
  153. winograd_algos.emplace_back(refhold.back().get());
  154. refhold.emplace_back(new AlgoFP16WinogradF63(
  155. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  156. tile_size));
  157. winograd_algos.emplace_back(refhold.back().get());
  158. refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
  159. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  160. tile_size));
  161. winograd_algos.emplace_back(refhold.back().get());
  162. #endif
  163. refhold.emplace_back(new AlgoS8WinogradF23_8x8(
  164. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  165. tile_size));
  166. winograd_algos.emplace_back(refhold.back().get());
  167. refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
  168. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  169. tile_size));
  170. winograd_algos.emplace_back(refhold.back().get());
  171. refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
  172. static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
  173. tile_size));
  174. winograd_algos.emplace_back(refhold.back().get());
  175. }
  176. }
  177. }
  178. SmallVector<AlgoBase*> direct_algos;
  179. SmallVector<AlgoBase*> winograd_algos;
  180. };
  181. SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
  182. static AlgoPack sl_algo_pack;
  183. auto&& algos = fallback::ConvBiasImpl::algo_pack();
  184. algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
  185. sl_algo_pack.direct_algos.end());
  186. algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
  187. sl_algo_pack.winograd_algos.end());
  188. return std::move(algos);
  189. }
// Backend identity tag: algorithms whose type pointer equals this address
// belong to arm_common. Only the address is meaningful; the pointed-to
// storage is a dummy byte that is never read.
void* const ConvBiasImpl::sm_arm_common_algo_type =
        &arm_common_algo_type_storage;
  192. bool ConvBiasImpl::is_matmul_quantized_prefer(
  193. const ConvBiasImpl::NCBKernSizeParam& param) const {
  194. fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
  195. param, 0, param::MatrixMul::Format::DEFAULT, {}, 0,
  196. BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);
  197. conv_ncb_param.dst_type = param.bias_type;
  198. conv_ncb_param.filter_meta.group = 1;
  199. bool conv_direct_unusable = false;
  200. if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
  201. param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
  202. conv_direct_unusable =
  203. !arm_common::direct_int8_stride1::can_conv_direct_stride1_int8(
  204. conv_ncb_param) &&
  205. !arm_common::direct_int8_stride2::can_conv_direct_stride2_int8(
  206. conv_ncb_param);
  207. } else if (param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) {
  208. conv_direct_unusable =
  209. !arm_common::direct_quint8_stride1::
  210. can_conv_direct_stride1_quint8(conv_ncb_param) &&
  211. !arm_common::direct_quint8_stride2::
  212. can_conv_direct_stride2_quint8(conv_ncb_param);
  213. }
  214. return conv_direct_unusable;
  215. }
  216. const char* ConvBiasImpl::get_algorithm_set_name() const {
  217. // arm common version 0
  218. return "AC0";
  219. }
  220. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台