You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

algos_conv1x1_gemv.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. /**
  2. * \file dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
  13. #include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
  14. #include "src/common/opr_delegate.h"
  15. #include "src/fallback/conv_bias/common.h"
  16. #include "src/fallback/conv_bias/opr_impl.h"
  17. #include "megdnn/opr_param_defs.h"
  18. #include "src/naive/convolution/helper.h"
  19. #include "src/fallback/matrix_mul/gemv.h"
  20. #if MEGDNN_X86
  21. #include "src/x86/conv_bias/postprocess_helper.h"
  22. #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
  23. #include "src/arm_common/conv_bias/postprocess_helper.h"
  24. #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
  25. #include "src/arm_common/matrix_mul/fp16/hgemv.h"
  26. #include "src/arm_common/matrix_mul/int8/gemv.h"
  27. #endif
  28. #include "midout.h"
  29. MIDOUT_DECL(megdnn_fallback_conv1x1_gemv)
  30. using namespace megdnn;
  31. using namespace fallback;
  32. #if MEGDNN_X86
  33. using namespace x86;
  34. #endif
  35. using namespace conv1x1;
  36. namespace {
  37. template <typename stype, typename btype, param::ConvBias::Format F>
  38. struct GemvLike {
  39. inline static void do_gemv(const stype* A, const stype* B, btype* C,
  40. size_t M, size_t N, size_t K, size_t LDA,
  41. size_t LDB, size_t LDC, DType src,
  42. DType filter) {
  43. MEGDNN_MARK_USED_VAR(A);
  44. MEGDNN_MARK_USED_VAR(B);
  45. MEGDNN_MARK_USED_VAR(C);
  46. MEGDNN_MARK_USED_VAR(M);
  47. MEGDNN_MARK_USED_VAR(N);
  48. MEGDNN_MARK_USED_VAR(K);
  49. MEGDNN_MARK_USED_VAR(LDA);
  50. MEGDNN_MARK_USED_VAR(LDB);
  51. MEGDNN_MARK_USED_VAR(LDC);
  52. MEGDNN_MARK_USED_VAR(src);
  53. MEGDNN_MARK_USED_VAR(filter);
  54. megdnn_assert(false,
  55. "unspported conv1x1 gemv : \nsrc_type : "
  56. "%s\nfilter_type : %s\n",
  57. src.name(), filter.name());
  58. }
  59. };
  60. template <typename stype, typename btype>
  61. struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
  62. inline static void do_gemv(const stype* A, const stype* B, btype* C,
  63. size_t M, size_t N, size_t K, size_t LDA,
  64. size_t LDB, size_t LDC, DType src,
  65. DType filter) {
  66. MEGDNN_MARK_USED_VAR(src);
  67. MEGDNN_MARK_USED_VAR(filter);
  68. megdnn::fallback::gemv_like<stype, btype>(A, B, C, M, N, K, LDA, LDB,
  69. LDC);
  70. }
  71. };
//! NCHW Quantized8Asymm: fallback gemv that applies per-operand zero
//! points while accumulating into int32.
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
    inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
                               dt_int32* C, size_t M, size_t N, size_t K,
                               size_t LDA, size_t LDB, size_t LDC, DType src,
                               DType filter) {
        //! NOTE(review): Conv1x1GemvWorker passes ncb_param.filter_type
        //! as `src` and ncb_param.src_type as `filter`, so zp0 is the
        //! zero point of A (the filter matrix) and zp1 that of B (the
        //! source vector) -- confirm the parameter naming is intentional.
        uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
        uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
        megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
                                                        LDB, LDC, zp0, zp1);
    }
};
  84. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  85. template <>
  86. struct GemvLike<dt_float32, dt_float32, param::ConvBias::Format::NCHW> {
  87. inline static void do_gemv(const dt_float32* A, const dt_float32* B,
  88. dt_float32* C, size_t M, size_t N, size_t K,
  89. size_t LDA, size_t LDB, size_t LDC, DType src,
  90. DType filter) {
  91. MEGDNN_MARK_USED_VAR(src);
  92. MEGDNN_MARK_USED_VAR(filter);
  93. megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
  94. }
  95. };
  96. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  97. template <>
  98. struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> {
  99. inline static void do_gemv(const dt_float16* A, const dt_float16* B,
  100. dt_float16* C, size_t M, size_t N, size_t K,
  101. size_t LDA, size_t LDB, size_t LDC, DType src,
  102. DType filter) {
  103. MEGDNN_MARK_USED_VAR(src);
  104. MEGDNN_MARK_USED_VAR(filter);
  105. megdnn::arm_common::gemv_like(reinterpret_cast<const __fp16*>(A),
  106. reinterpret_cast<const __fp16*>(B),
  107. reinterpret_cast<__fp16*>(C), M, N, K,
  108. LDA, LDB, LDC);
  109. }
  110. };
  111. #endif
  112. template <>
  113. struct GemvLike<dt_int8, dt_int32, param::ConvBias::Format::NCHW> {
  114. inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int32* C,
  115. size_t M, size_t N, size_t K, size_t LDA,
  116. size_t LDB, size_t LDC, DType src,
  117. DType filter) {
  118. MEGDNN_MARK_USED_VAR(src);
  119. MEGDNN_MARK_USED_VAR(filter);
  120. megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
  121. }
  122. };
  123. template <typename stype, typename btype>
  124. struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
  125. inline static void do_gemv(const stype* A, const stype* B, btype* C,
  126. size_t M, size_t N, size_t K, size_t LDA,
  127. size_t LDB, size_t LDC, DType src,
  128. DType filter) {
  129. MEGDNN_MARK_USED_VAR(src);
  130. MEGDNN_MARK_USED_VAR(filter);
  131. megdnn::arm_common::gemv_like_mk4(A, B, C, M, N, K, LDA, LDB, LDC);
  132. }
  133. };
  134. #if __ARM_FEATURE_DOTPROD
  135. template <typename stype, typename btype>
  136. struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
  137. inline static void do_gemv(const stype* A, const stype* B, btype* C,
  138. size_t M, size_t N, size_t K, size_t LDA,
  139. size_t LDB, size_t LDC, DType src,
  140. DType filter) {
  141. MEGDNN_MARK_USED_VAR(src);
  142. MEGDNN_MARK_USED_VAR(filter);
  143. megdnn::arm_common::gemv_like_mk4_dot(A, B, C, M, N, K, LDA, LDB, LDC);
  144. }
  145. };
  146. #endif
  147. #endif
//! Worker executed once per (batch, group, oc-tile) ndrange cell: runs a
//! single gemv of shape (oc_tile x IC) * (IC x 1) for its slice of output
//! channels, then applies the bias/activation postprocess.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode,
          param::ConvBias::Format format>
struct Conv1x1GemvWorker {
    static void exec(WorkspaceBundle& whole_bundle,
                     WorkspaceBundle& thread_bundle, size_t oc_tile_size,
                     const ConvBiasImpl::NCBKernSizeParam& param,
                     const ConvBiasImpl::NCBKernParam& ncb_param,
                     const ConvBiasImpl::NCBKernIndex& ncb_index) {
        //! rebase the whole bundle onto this invocation's workspace
        whole_bundle.set(ncb_param.workspace_ptr);

        size_t OC = param.filter_meta.ocpg;
        size_t IC = param.filter_meta.icpg;

        //! ndrange layout is {BATCH, GROUP, oc_blocks_per_group}
        //! (see dispatch_kerns)
        size_t batch_id = ncb_index.ndrange_id[0];
        size_t group_id = ncb_index.ndrange_id[1];
        size_t oc_tile_id_in_group = ncb_index.ndrange_id[2];
        size_t thread_id = ncb_index.thread_id;

        //! [oc_start, oc_end): channel range of this tile, clamped so the
        //! last tile does not run past OC
        size_t oc_start = oc_tile_size * oc_tile_id_in_group;
        size_t oc_end = oc_start + oc_tile_size;
        oc_end = (oc_end <= OC ? oc_end : OC);

        //! A = filter rows for this tile (element, not byte, offset)
        size_t numbers_of_ncb_filter_offset =
                oc_tile_size * IC * oc_tile_id_in_group;
        const src_ctype* Aptr = ncb_param.filter<src_ctype>(group_id) +
                                numbers_of_ncb_filter_offset;
        //! B = the whole input "vector" (spatial size is 1x1 here)
        const src_ctype* Bptr = ncb_param.src<src_ctype>(batch_id, group_id);

        //! per-thread scratch lives at thread_id * per-thread-size inside
        //! bundle slot 0; the matmul temp dst sits after sub-slot 0
        size_t thread_offset = thread_bundle.total_size_in_bytes() * thread_id;
        size_t bytes_offset_of_matmul_dst_this_thread =
                thread_offset + thread_bundle.get_size(0);
        bias_ctype* matmul_temp_dst = reinterpret_cast<bias_ctype*>(
                reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
                bytes_offset_of_matmul_dst_this_thread);

        size_t numbers_of_ncb_dst_offset = oc_tile_size * oc_tile_id_in_group;
        dst_ctype* conv_bias_dst =
                ncb_param.dst<dst_ctype>(batch_id, group_id) +
                numbers_of_ncb_dst_offset;

        //! when dst is 8-bit the 32-bit accumulator cannot be written in
        //! place: gemv into the temp buffer, then requantize in postprocess
        bool is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        bias_ctype* gemv_dst =
                is_dst_8bit ? matmul_temp_dst
                            : reinterpret_cast<bias_ctype*>(conv_bias_dst);

        size_t pack_size = megdnn::fallback::pack_size(format);
        //! M = channels in this tile, N = 1, K = IC; leading dimensions
        //! are scaled by pack_size for the blocked (NCHW44) layouts.
        //! NOTE(review): filter_type is passed as GemvLike's `src` and
        //! src_type as its `filter` -- consistent with A being the filter,
        //! but the naming is confusing; confirm.
        GemvLike<src_ctype, bias_ctype, format>::do_gemv(
                Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC * pack_size,
                pack_size, pack_size, ncb_param.filter_type,
                ncb_param.src_type);

        //! do postprocess
        void* bias_ptr = nullptr;
        if (param.bias_mode != megdnn::BiasMode::NO_BIAS) {
            bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
                    ncb_param.bias<bias_ctype>(batch_id, group_id) +
                    numbers_of_ncb_dst_offset));
        }
        //! output "image" is 1x1, batch 1, (oc_end - oc_start) channels
        PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
                gemv_dst, bias_ptr, conv_bias_dst, param.bias_mode,
                param.nonlineMode, param.bias_type, param.dst_type, 1_z,
                oc_end - oc_start, 1, 1, 1);
    }
};
  209. } // namespace
  210. size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
  211. const NCBKernSizeParam& param) const {
  212. MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
  213. midout_iv("AlgoConv1x1Gemv::get_oc_tile"_hash)) {
  214. size_t OC = param.filter_meta.ocpg;
  215. size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
  216. return round_up<size_t>(oc_block_size_one_thread, 16);
  217. }
  218. MIDOUT_END();
  219. }
  220. size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
  221. const NCBKernSizeParam& param) const {
  222. MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
  223. midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) {
  224. size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
  225. auto thread_bundle =
  226. utils::get_thread_bundle(param, 0, compt_oc_block_size);
  227. return WorkspaceBundle{
  228. nullptr,
  229. {thread_bundle.total_size_in_bytes() * param.nr_threads}}
  230. .total_size_in_bytes();
  231. }
  232. MIDOUT_END();
  233. }
//! Build the kernel list: pick the Conv1x1GemvWorker instantiation that
//! matches (format, src/bias/dst dtypes) and wrap it in a single NCBKern
//! dispatched over the {BATCH, GROUP, oc_blocks_per_group} ndrange.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    SmallVector<ConvBiasImpl::NCBKern> ret_kern;
    size_t OC = param.filter_meta.ocpg;
    size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
    size_t GROUP = param.filter_meta.group;
    size_t BATCH = param.n;
    //! third ndrange axis: how many oc tiles each group is split into
    size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
    //! get thread bundle
    auto thread_bundle =
            utils::get_thread_bundle(param, 0, compt_oc_block_size);
    auto whole_bundle = WorkspaceBundle{
            nullptr, {thread_bundle.total_size_in_bytes() * param.nr_threads}};
    using conv1x1_gemv_kern =
            std::function<void(WorkspaceBundle&, WorkspaceBundle&, size_t,
                               const ConvBiasImpl::NCBKernSizeParam&,
                               const ConvBiasImpl::NCBKernParam&,
                               const ConvBiasImpl::NCBKernIndex&)>;
    conv1x1_gemv_kern conv1x1_gemv_worker = nullptr;
//! cb1: src, bias and dst all share the single ctype _dt
#define cb1(_format, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
            conv1x1_gemv_worker = \
                    Conv1x1GemvWorker<_dt, _dt, _dt, _post_ctype, _post_ctype, \
                                      _postprocess_mode, _format>::exec; \
        } \
    } \
    MIDOUT_END()
//! cb2: mixed dtypes; postprocess op types taken from the DType traits
//! (used for the requantizing 8-bit-dst paths)
#define cb2(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
            _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
        if (param.filter_type.enumv() == param.src_type.enumv() && \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            conv1x1_gemv_worker = \
                    Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
                                      DTypeTrait<_i_bias_type>::ctype, \
                                      DTypeTrait<_i_dst_type>::ctype, \
                                      _postprocess_mode, _format>::exec; \
        } \
    } \
    MIDOUT_END()
//! cb3: mixed dtypes; postprocess op types are the raw bias/dst ctypes
//! (used for the 16/32-bit-dst ADD_BIAS paths)
#define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
            _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
        if (param.filter_type.enumv() == param.src_type.enumv() && \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            conv1x1_gemv_worker = \
                    Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
                                      _bias_ctype, _dst_ctype, \
                                      _postprocess_mode, _format>::exec; \
        } \
    } \
    MIDOUT_END()
    switch (param.filter_meta.format) {
        case param::ConvBias::Format::NCHW:
            cb1(param::ConvBias::Format::NCHW, dt_float32, dt_float32,
                PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT"_hash);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
                PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
#else
#if !MEGDNN_DISABLE_FLOAT16
            cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16,
                PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
#endif
#endif
            cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
                dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW::GEMV::INT8x8x32_INT32"_hash);
            cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16,
                dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS,
                "NCHW::GEMV::INT8x8x16_INT16"_hash);
            cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
                dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW::GEMV::QINT8x8x32_QINT32"_hash);
            cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
                dt_int8, PostprocessMode::QUANTIZED,
                "NCHW::GEMV::QINT8x8x32_QINT8"_hash);
            cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
                dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32,
                dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW::GEMV::QUINT8x8x32_QINT32"_hash);
            cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
                dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32,
                dt_uint8, PostprocessMode::QUANTIZED,
                "NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
            break;
        //! 8x8x16 is not supported for NCHW44 (rejected in usable())
        case param::ConvBias::Format::NCHW44:
            cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32,
                PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash);
            cb3(param::ConvBias::Format::NCHW44, dt_int8, dt_int32, dt_int32,
                dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW44::GEMV::INT8x8x32_INT32"_hash);
            cb3(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
                dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW44::GEMV::QINT8x8x32_QINT32"_hash);
            cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
                dt_int8, PostprocessMode::QUANTIZED,
                "NCHW44::GEMV::QINT8x8x32_QINT8"_hash);
            break;
        //! 8x8x16 is not supported for NCHW44_DOT either
        case param::ConvBias::Format::NCHW44_DOT:
            cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
                dt_int32, dt_int8, dt_int32, dt_int32,
                PostprocessMode::ADD_BIAS,
                "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash);
            cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
                dt_int32, PostprocessMode::ADD_BIAS,
                "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash);
            cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
                dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
                dt_int8, PostprocessMode::QUANTIZED,
                "NCHW44_DOT::GEMV::QINT8x8x32_QINT8"_hash);
            break;
        default:
            megdnn_throw("Invalid Format");
            break;
    }
#undef cb1
#undef cb2
#undef cb3
    megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker");
    //! everything is captured by value (the lambda may outlive this call);
    //! `mutable` lets the worker rebase the captured bundles.
    //! NOTE(review): std::move(ncb_index) moves from a const reference and
    //! therefore degenerates to a copy -- harmless but misleading.
    auto kern_compt =
            [compt_oc_block_size, param, conv1x1_gemv_worker, whole_bundle,
             thread_bundle](
                    const ConvBiasImpl::NCBKernParam& ncb_param,
                    const ConvBiasImpl::NCBKernIndex& ncb_index) mutable {
                conv1x1_gemv_worker(whole_bundle, thread_bundle,
                                    compt_oc_block_size, param, ncb_param,
                                    std::move(ncb_index));
            };
    ret_kern.push_back({kern_compt, {BATCH, GROUP, oc_blocks_per_group}});
    return ret_kern;
}
//! Eligibility check: this algo only handles a true gemv case -- 1x1
//! filter, unit stride, no padding, 1x1 output -- on a supported
//! (format, dtype) combination.
bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param,
                                           AlgoSelectionStrategy) const {
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
                 midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
        auto format = param.filter_meta.format;
        size_t FH = param.filter_meta.spatial[0],
               FW = param.filter_meta.spatial[1];
        size_t PH = param.filter_meta.padding[0],
               PW = param.filter_meta.padding[1];
        size_t SH = param.filter_meta.stride[0],
               SW = param.filter_meta.stride[1];
        size_t OH = param.osz[0];
        size_t OW = param.osz[1];
        //! whether gemv and 1x1
        if (OH * OW != 1 || FH != 1 || FW != 1 || PH || PW || SH != 1 ||
            SW != 1) {
            return false;
        }
        //! blocked formats (NCHW44 / NCHW44_DOT) only have ARM kernels
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        if (format != param::ConvBias::Format::NCHW &&
            format != param::ConvBias::Format::NCHW44 &&
            format != param::ConvBias::Format::NCHW44_DOT) {
            return false;
        }
#else
        if (format != param::ConvBias::Format::NCHW) {
            return false;
        }
#endif
        //! supports a few dtypes: src must match filter and be one of
        //! int8 / qint8 / quint8 / (fp16) / fp32
        if (param.src_type.enumv() != param.filter_type.enumv() ||
            (param.src_type.enumv() != DTypeEnum::Int8 &&
             param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
             param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
             param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
             param.src_type.enumv() != DTypeEnum::Float32)) {
            return false;
        }
        //! x86 disable Quntized8Asymm
#if MEGDNN_X86
        if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
            return false;
        }
#endif
        if (format == param::ConvBias::Format::NCHW44) {
            //! NCHW44 supports only fp32 / int8 / qint8
            if (param.src_type.enumv() != DTypeEnum::Float32 &&
                param.src_type.enumv() != DTypeEnum::Int8 &&
                param.src_type.enumv() != DTypeEnum::QuantizedS8) {
                return false;
            }
            //! 8x8x16 is not support nchw44
            if (param.src_type.enumv() == DTypeEnum::Int8 &&
                param.dst_type.enumv() == DTypeEnum::Int16) {
                return false;
            }
        } else if (format == param::ConvBias::Format::NCHW44_DOT) {
            //! NCHW44_DOT supports only int8 / qint8, and never 16-bit dst
            if ((param.src_type.enumv() != DTypeEnum::Int8 &&
                 param.src_type.enumv() != DTypeEnum::QuantizedS8) ||
                param.dst_type.enumv() == DTypeEnum::Int16) {
                return false;
            }
        }
        //! make sure 8x8x16 and 8x8x32 biasmode nonlineMode is identity
        //! otherwise return false
        if (param.dst_type.enumv() == DTypeEnum::Int16 ||
            param.dst_type.enumv() == DTypeEnum::Int32 ||
            param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
            if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
                return false;
            }
        }
        //! even no naive support in gemv (int16 src is unimplemented)
        if ((param.src_type.enumv() == param.filter_type.enumv() &&
             param.src_type.enumv() == DTypeEnum::Int16) &&
            param.dst_type.enumv() == DTypeEnum::Int32) {
            return false;
        }
        //! finally: no dilation, default compute mode
        return (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
    }
    MIDOUT_END();
    return false;
}
  464. bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
  465. const NCBKernSizeParam& param) const {
  466. MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
  467. midout_iv("AlgoConv1x1Gemv::is_preferred"_hash)) {
  468. #if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
  469. if (param.filter_meta.format == param::ConvBias::Format::NCHW &&
  470. param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
  471. return false;
  472. }
  473. #endif
  474. return true;
  475. }
  476. MIDOUT_END();
  477. return false;
  478. }
  479. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台