
opr_impl.cpp 12 kB

/**
 * \file dnn/src/arm_common/conv_bias/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/arm_common/convolution/opr_impl.h"
#include "src/arm_common/matrix_mul/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "include/megdnn/oprs/nn.h"
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/int8/stride1.h"
#include "src/arm_common/conv_bias/int8/stride2.h"
#include "src/arm_common/conv_bias/quint8/stride1.h"
#include "src/arm_common/conv_bias/quint8/stride2.h"
#include "src/arm_common/convolution/opr_impl.h"

using namespace megdnn;
using namespace arm_common;

namespace {
uint8_t arm_common_algo_type_storage;
}  // anonymous namespace

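// AlgoPack holds every arm_common direct conv_bias algorithm as a member
// object; its constructor also wraps each usable arm_common matrix-mul
// algorithm into a family of winograd conv_bias algorithms (see below).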
class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true};
    AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false};
    AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true};
    AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false};
    AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
    AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
    AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
    AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44;
    AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
    AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
    AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
    AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
    AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
    AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
    AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
    AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
    AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
    AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false};
    AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true};
    AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false};
    AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true};
    AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
#endif
    AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
    AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
    AlgoF32DirectNCHW44 f32_direct_nchw44;
    AlgoF32Direct f32_direct_large_group{true};
    AlgoF32Direct f32_direct_small_group{false};
    AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
    AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
    AlgoF32DirectStride1 f32_direct_stride1_large_group{true};
    AlgoF32DirectStride1 f32_direct_stride1_small_group{false};
    AlgoI8x8x16Direct i8x8x16_direct_large_group{true};
    AlgoI8x8x16Direct i8x8x16_direct_small_group{false};
    AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true};
    AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false};
    AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    AlgoF16Direct f16_direct_large_group{true};
    AlgoF16Direct f16_direct_small_group{false};
    AlgoF16DirectStride1 f16_direct_stride1_large_group{true};
    AlgoF16DirectStride1 f16_direct_stride1_small_group{false};
#endif

    SmallVector<std::unique_ptr<AlgoBase>> refhold;
public:
    AlgoPack() {
#if __ARM_FEATURE_DOTPROD
        direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
        direct_algos.emplace_back(&ds8_direct_stride1_large_group);
        direct_algos.emplace_back(&ds8_direct_stride1_small_group);
        direct_algos.emplace_back(&ds8_direct_stride2_large_group);
        direct_algos.emplace_back(&ds8_direct_stride2_small_group);
        direct_algos.emplace_back(&du8_direct_stride1_large_group);
        direct_algos.emplace_back(&du8_direct_stride1_small_group);
        direct_algos.emplace_back(&du8_direct_stride2_large_group);
        direct_algos.emplace_back(&du8_direct_stride2_small_group);
#endif
        direct_algos.emplace_back(&qu8_direct_stride2_large_group);
        direct_algos.emplace_back(&qu8_direct_stride2_small_group);
        direct_algos.emplace_back(&qu8_direct_stride1_large_group);
        direct_algos.emplace_back(&qu8_direct_stride1_small_group);
        direct_algos.emplace_back(&s8_direct_stride2_large_group);
        direct_algos.emplace_back(&s8_direct_stride2_small_group);
        direct_algos.emplace_back(&s8_direct_stride2_nchw44);
        direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44);
        direct_algos.emplace_back(&s8_direct_stride1_large_group);
        direct_algos.emplace_back(&s8_direct_stride1_small_group);
        direct_algos.emplace_back(&s8_direct_stride1_nchw44);
        direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
        direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        direct_algos.emplace_back(&f16_direct_stride1_large_group);
        direct_algos.emplace_back(&f16_direct_stride1_small_group);
        direct_algos.emplace_back(&f16_direct_large_group);
        direct_algos.emplace_back(&f16_direct_small_group);
#endif
        direct_algos.emplace_back(&i8x8x16_direct_large_group);
        direct_algos.emplace_back(&i8x8x16_direct_small_group);
        direct_algos.emplace_back(&i8x8x16_stride2_filter2);
        direct_algos.emplace_back(&i8x8x16_stride2_large_group);
        direct_algos.emplace_back(&i8x8x16_stride2_small_group);
        direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
        direct_algos.emplace_back(&f32_chanel_wise_nchw44);
        direct_algos.emplace_back(&f32_direct_nchw44);
        direct_algos.emplace_back(&f32_direct_stride1_large_group);
        direct_algos.emplace_back(&f32_direct_stride1_small_group);
        direct_algos.emplace_back(&f32_direct_stride2_large_group);
        direct_algos.emplace_back(&f32_direct_stride2_small_group);
        direct_algos.emplace_back(&f32_direct_large_group);
        direct_algos.emplace_back(&f32_direct_small_group);

        static CpuOprDelegationStorage<2> storage;
        auto matmul_opr = storage.get<MatrixMul, 0>();
        auto&& matmul_algos =
                static_cast<arm_common::MatrixMulImpl*>(matmul_opr)->algo_pack();
        for (auto&& algo : matmul_algos) {
            if (algo->type() == nullptr)
                continue;
            for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) {
                refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF63(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF54(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF45(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                refhold.emplace_back(new AlgoFP16WinogradF23(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP16WinogradF45(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP16WinogradF63(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
                refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
#endif
                refhold.emplace_back(new AlgoS8WinogradF23_8x8(
                        static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                        tile_size));
                winograd_algos.emplace_back(refhold.back().get());
            }
        }
    }

    SmallVector<AlgoBase*> direct_algos;
    SmallVector<AlgoBase*> winograd_algos;
};

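// Direct algorithms are inserted ahead of the algorithms inherited from the
// fallback implementation, and the matmul-backed winograd algorithms are
// appended after them.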
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    auto&& algos = fallback::ConvBiasImpl::algo_pack();
    algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
                 sl_algo_pack.direct_algos.end());
    algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
                 sl_algo_pack.winograd_algos.end());
    return std::move(algos);
}

void* const ConvBiasImpl::sm_arm_common_algo_type =
        &arm_common_algo_type_storage;

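// Returns true when none of the direct int8 / quint8 stride-1 / stride-2
// kernels can handle the given problem size, i.e. a matmul-based algorithm
// should be preferred for the quantized convolution.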
bool ConvBiasImpl::is_matmul_quantized_prefer(
        const ConvBiasImpl::NCBKernSizeParam& param) {
    // fallback::ConvBiasImpl::NCBKernParam conv_ncb_param;
    fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
            param, 0, param::MatrixMul::Format::DEFAULT, {}, 0,
            BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);
    conv_ncb_param.dst_type = param.bias_type;
    conv_ncb_param.filter_meta.group = 1;

    bool conv_direct_unusable = false;
    if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
        param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
        conv_direct_unusable =
                !arm_common::direct_int8_stride1::can_conv_direct_stride1_int8(
                        conv_ncb_param) &&
                !arm_common::direct_int8_stride2::can_conv_direct_stride2_int8(
                        conv_ncb_param);
    } else if (param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) {
        conv_direct_unusable =
                !arm_common::direct_quint8_stride1::can_conv_direct_stride1_quint8(
                        conv_ncb_param) &&
                !arm_common::direct_quint8_stride2::can_conv_direct_stride2_quint8(
                        conv_ncb_param);
    }
    return conv_direct_unusable;
}

const char* ConvBiasImpl::get_algorithm_set_name() const {
    // arm common version 0
    return "AC0";
}

// vim: syntax=cpp.doxygen
