You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

factory.h 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. /**
  2. * \file dnn/src/fallback/conv_bias/im2col/factory.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <unordered_map>
  13. #include "src/fallback/conv_bias/im2col/strategy_base.h"
  14. #include "src/fallback/conv_bias/opr_impl.h"
  15. #include "midout.h"
  16. MIDOUT_DECL(megdnn_fallback_im2col_factory_make_strategy)
  17. namespace megdnn {
  18. namespace fallback {
  19. namespace im2col {
//! Tag selecting which concrete im2col strategy instantiation to build.
//! Derived from the src/filter/dst dtypes of the convolution in
//! Factory::get_strategy_type() below.
enum class StrategyType : uint32_t {
    FLOAT = 0,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    //! fp16 data computed with native __fp16 arithmetic
    FLOAT_FP16 = 1,
#else
#if !MEGDNN_DISABLE_FLOAT16
    //! fp16 data without native fp16 vector arithmetic
    FLOAT16_FLOAT16 = 2,
#endif
#endif
    INT8x8x32 = 3,
    INT8x8x16 = 4,
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
    QUINT8x8x32 = 5,
    QUINT8x8x32x8 = 6,
#endif
    QINT8x8x32 = 7,
    QINT8x8x32x8 = 8
};
//! Cache key for constructed strategies: captures every property that
//! selects a distinct Strategy instantiation (dtypes via param, layout
//! format, matmul pack mode and blocking, kernel/stride geometry).
struct StrategyHashParam {
    bool is_xcorr;   //! copied from filter_meta.should_flip
    bool is_square;  //! kernel_h == kernel_w, stride_h = stride_w
    size_t block_m;  //! matmul inner block sizes; zero in no-pack mode
    size_t block_n;
    size_t block_k;
    size_t kernel;   //! filter_meta.spatial[0]
    size_t stride;   //! filter_meta.stride[0]
    fallback::ConvBiasImpl::NCBKernSizeParam param;
    param::ConvBias::Format format;
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
};
  50. struct StrategyHashParamHash {
  51. uint64_t operator()(const StrategyHashParam& sparam) const {
  52. constexpr uint64_t base = 1; //! avoid hashkey is zero
  53. uint64_t result =
  54. static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
  55. result = result ^
  56. ((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
  57. << 3);
  58. result = result ^
  59. ((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
  60. base)
  61. << 6);
  62. result = result ^
  63. ((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
  64. << 9);
  65. result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
  66. result = result ^
  67. ((static_cast<uint64_t>(sparam.packmode) + base) << 15);
  68. result =
  69. result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
  70. result =
  71. result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
  72. result =
  73. result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
  74. result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
  75. result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
  76. result = result ^
  77. ((static_cast<uint64_t>(sparam.is_square) + base) << 35);
  78. result = result ^
  79. ((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
  80. return result;
  81. };
  82. };
  83. struct StrategyHashParamEqual {
  84. bool operator()(const StrategyHashParam& param1,
  85. const StrategyHashParam& param2) const {
  86. bool flags = true;
  87. flags = param1.param.src_type == param2.param.src_type && flags;
  88. flags = param1.param.filter_type == param2.param.filter_type && flags;
  89. flags = param1.param.bias_type == param2.param.bias_type && flags;
  90. flags = param1.param.dst_type == param2.param.dst_type && flags;
  91. flags = param1.format == param2.format && flags;
  92. flags = param1.packmode == param2.packmode && flags;
  93. flags = param1.block_m == param2.block_m && flags;
  94. flags = param1.block_n == param2.block_n && flags;
  95. flags = param1.block_k == param2.block_k && flags;
  96. flags = param1.kernel == param2.kernel && flags;
  97. flags = param1.stride == param2.stride && flags;
  98. flags = param1.is_square == param2.is_square && flags;
  99. flags = param1.is_xcorr == param2.is_xcorr && flags;
  100. return flags;
  101. };
  102. };
//! Cache of constructed Strategy objects keyed by StrategyHashParam.
//! Thread-safe: get() acquires m_mtx before touching the map, so one
//! instance may be shared across threads (see the function-local static
//! in Factory::get_im2col_strategy).
class StrategyDelegationStorage {
    std::mutex m_mtx;
    std::unordered_map<StrategyHashParam, std::unique_ptr<StrategyBase>,
                       StrategyHashParamHash, StrategyHashParamEqual>
            map_strategys;

public:
    ~StrategyDelegationStorage() = default;

    //! Return the cached strategy for \p param / \p stype, constructing and
    //! inserting it on first use. The returned pointer is non-owning; this
    //! storage keeps ownership. Defined at the end of this header.
    template <typename Strategy>
    Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                  const fallback::ConvBiasImpl::NCBKernSizeParam& param,
                  StrategyType stype);
};
//! Factory that builds the concrete im2col StrategyBase matching a
//! conv-bias kernel-size param, the matmul algorithm's pack mode and the
//! dtype-derived StrategyType tag. Construction is wrapped in midout
//! instrumentation; every path either returns a strategy or throws.
class Factory {
public:
    //! Entry point: derive the StrategyType from \p param and fetch the
    //! (lazily constructed, cached) strategy from a function-local
    //! singleton storage.
    static StrategyBase* get_im2col_strategy(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
        static StrategyDelegationStorage storage;
        StrategyType strategytype = get_strategy_type(param);
        return storage.get<StrategyBase>(matmul_algo, param, strategytype);
    }

    //! Map the src/filter/dst dtypes of \p param onto a StrategyType tag;
    //! throws if the dtype combination is unsupported by im2col.
//! cb1 matches on the filter dtype alone (_post_ctype is unused in this
//! first definition); cb2 additionally requires src == filter dtype and a
//! matching dst dtype. Both are #undef'ed right after use.
    static StrategyType get_strategy_type(
            const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
#define cb1(_dt, _post_ctype, _strategytype)                   \
    if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
        return _strategytype;                                  \
    }
#define cb2(_i_src_type, _i_bias_type, _i_dst_type, _src_ctype, _bias_ctype, \
            _dst_ctype, _strategytype)                                       \
    if (param.filter_type.enumv() == param.src_type.enumv() &&               \
        param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&          \
        param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {          \
        return _strategytype;                                                \
    }
        cb1(dt_float32, dt_float32, StrategyType::FLOAT);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16);
#else
#if !MEGDNN_DISABLE_FLOAT16
        cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16);
#endif
#endif
        cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
            StrategyType::INT8x8x32);
        cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
            StrategyType::INT8x8x16);
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32);
        cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm,
            dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8);
#endif
        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
            dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32);
        cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8,
            dt_int8, dt_int32, dt_int8, StrategyType::QINT8x8x32x8);
#undef cb1
#undef cb2
        megdnn_throw("not support datatype in im2col strategy\n");
    }

//! cb1/cb2 are redefined here to *construct* the matching Strategy (inside
//! a MIDOUT_BEGIN/END region so unused instantiations can be stripped)
//! instead of just returning its tag. They expand to a full
//! if-return / MIDOUT_END / return{} sequence, which is why each call site
//! below is the last statement of its branch.
#define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode,  \
            _midout_tag)                                              \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,        \
                 midout_iv(_midout_tag)) {                            \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {    \
            return std::make_unique<                                  \
                    Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
                             _postprocess_mode, PackMode::_packmode,  \
                             FormatMode::_format>>();                 \
        }                                                             \
    }                                                                 \
    MIDOUT_END();                                                     \
    return {};
#define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
            _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode,     \
            _midout_tag)                                                \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,          \
                 midout_iv(_midout_tag)) {                              \
        if (param.filter_type.enumv() == param.src_type.enumv() &&      \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
            return std::make_unique<Strategy<                           \
                    _src_ctype, _bias_ctype, _dst_ctype,                \
                    DTypeTrait<_i_bias_type>::ctype,                    \
                    DTypeTrait<_i_dst_type>::ctype, _postprocess_mode,  \
                    PackMode::_packmode, FormatMode::_format>>();       \
        }                                                               \
    }                                                                   \
    MIDOUT_END();                                                       \
    return {};

    //! Build a strategy for matmul algos using the DEFAULT pack mode.
    //! Special-cases some NCHW44 / NCHW44_DOT kernel geometries with fused
    //! im2col+pack implementations on ARM before falling back to cb1/cb2.
    static std::unique_ptr<StrategyBase> make_default_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        //! matmul_algo is only read inside the ARM-only #if blocks below
        MEGDNN_MARK_USED_VAR(matmul_algo);
        param::ConvBias::Format format = param.filter_meta.format;
        switch (strategytype) {
            case StrategyType::FLOAT:
                if (format == param::ConvBias::Format::NCHW) {
                    cb1(NCHW, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyType::FLOAT"_hash);
                } else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
                    auto matmul_block = matmul_algo->get_inner_block_size();
                    //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse
                    if ((matmul_block.m == 8 || matmul_block.m == 4) &&
                        matmul_block.n == 12 && matmul_block.k == 1 &&
                        param.filter_meta.spatial[0] == 3 &&
                        param.filter_meta.spatial[1] == 3 &&
                        param.filter_meta.stride[0] == 2 &&
                        param.filter_meta.stride[1] == 2 &&
                        !param.filter_meta.should_flip) {
                        MIDOUT_BEGIN(
                                megdnn_fallback_im2col_factory_make_strategy,
                                midout_iv(
                                        "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) {
                            return std::make_unique<
                                    StrategyFuseXx12x1Nchw44K3x3S2<
                                            float, float,
                                            PostprocessMode::FLOAT>>();
                        }
                        MIDOUT_END();
                        return {};
                    }
#endif
                    cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
                        PostprocessMode::FLOAT,
                        "DefaultStrategyTypeNCHW44::FLOAT"_hash);
                } else {
                    megdnn_throw(
                            ssprintf("Current only support layout "
                                     "NCHW44/NCHW for im2col "
                                     "algo, but got %d\n",
                                     uint32_t(format)));
                }
                break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            case StrategyType::FLOAT_FP16:
                cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
                    "DefaultStrategyType::FLOAT_FP16"_hash);
                break;
#else
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
                cb1(NCHW, DEFAULT, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::FLOAT16_FLOAT16"_hash);
                break;
#endif
#endif
            case StrategyType::INT8x8x32:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyType::INT8x8x32"_hash);
                } else {
                    //! NOTE(review): "NCHW_DOT" here presumably means
                    //! "NCHW44_DOT"; message left unchanged (runtime string).
                    megdnn_throw(
                            ssprintf("Current only support layout "
                                     "NCHW44/NCHW/NCHW_DOT for im2col "
                                     "algo, but got %d\n",
                                     uint32_t(format)));
                }
                break;
            case StrategyType::INT8x8x16:
                cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::INT8x8x16"_hash);
                break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case StrategyType::QUINT8x8x32:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
                    PostprocessMode::NO_PROCESS,
                    "DefaultStrategyType::QUINT8x8x32"_hash);
                break;
            case StrategyType::QUINT8x8x32x8:
                cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
                    dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
                    PostprocessMode::QUANTIZED,
                    "DefaultStrategyType::QUINT8x8x32x8"_hash);
                break;
#endif
            case StrategyType::QINT8x8x32:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
                        PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    //! NOTE(review): "HCHW44" in the midout tag below looks
                    //! like a typo for "NCHW44"; the tag only seeds a hash,
                    //! so it is left as-is to keep instrumentation stable.
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
                        dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                        "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
                } else {
                    megdnn_throw(
                            ssprintf("Current only support layout "
                                     "NCHW44/NCHW/NCHW_DOT for im2col "
                                     "algo, but got %d\n",
                                     uint32_t(format)));
                }
                break;
            case StrategyType::QINT8x8x32x8:
                if (format == param::ConvBias::Format::NCHW) {
                    cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
                        dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
                        PostprocessMode::QUANTIZED,
                        "DefaultStrategyType::QINT8x8x32x8"_hash);
                } else if (format == param::ConvBias::Format::NCHW44 ||
                           format == param::ConvBias::Format::NCHW44_DOT) {
                    if (format == param::ConvBias::Format::NCHW44) {
                        //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse
#if MEGDNN_AARCH64
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        if (matmul_block.m == 4 && matmul_block.n == 4 &&
                            matmul_block.k == 16 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 1 &&
                            param.filter_meta.stride[1] == 1 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_4x4x16"_hash)) {
                                return std::make_unique<
                                        StrategyFuse4x4x16Nchw44<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
                    } else {
                        //! NCHW44_DOT: fused kernels per architecture
#if MEGDNN_AARCH64
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse
                        if (matmul_block.m == 8 && matmul_block.n == 12 &&
                            matmul_block.k == 4 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 1 &&
                            param.filter_meta.stride[1] == 1 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_8x12x4"_hash)) {
                                return std::make_unique<
                                        StrategyFuse8x12x4Nchw44Dot<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
#if MEGDNN_ARMV7
                        //! armv7 3x3s2 8x4x4 fused variant
                        auto matmul_block = matmul_algo->get_inner_block_size();
                        if (matmul_block.m == 8 && matmul_block.n == 4 &&
                            matmul_block.k == 4 &&
                            param.filter_meta.spatial[0] == 3 &&
                            param.filter_meta.spatial[1] == 3 &&
                            param.filter_meta.stride[0] == 2 &&
                            param.filter_meta.stride[1] == 2 &&
                            !param.filter_meta.should_flip) {
                            MIDOUT_BEGIN(
                                    megdnn_fallback_im2col_factory_make_strategy,
                                    midout_iv(
                                            "DefaultStrategyType::INT8x8x32_8x4x4_s2"_hash)) {
                                return std::make_unique<
                                        StrategyFuse8x4x4Nchw44DotK3x3S2<
                                                dt_qint32, dt_qint8,
                                                PostprocessMode::QUANTIZED>>();
                            }
                            MIDOUT_END();
                            return {};
                        }
#endif
                    }
                    cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
                        dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
                        dt_int32, dt_int8, PostprocessMode::QUANTIZED,
                        "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
                } else {
                    megdnn_throw(ssprintf("Current only support layout "
                                          "NCHW44/NCHW/NCHW_DOT for im2col "
                                          "algo, but got %d\n",
                                          uint32_t(format)));
                }
                break;
        }
        //! reached when the switch matched no case (e.g. tag compiled out)
        megdnn_throw(ssprintf("Unsupported strategy type %u in default mode",
                              uint32_t(strategytype)));
    }

    //! Build a strategy for matmul algos using the NO_PACK pack mode;
    //! only NCHW and a subset of dtypes are supported here.
    static std::unique_ptr<StrategyBase> make_nopack_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
                cb1(NCHW, NO_PACK, dt_float32, dt_float32,
                    PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
                break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#else
#if !MEGDNN_DISABLE_FLOAT16
            case StrategyType::FLOAT16_FLOAT16:
                cb1(NCHW, NO_PACK, dt_float16, dt_float16,
                    PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::FLOAT16_FLOAT16"_hash);
                break;
#endif
#endif
            case StrategyType::INT8x8x16:
                cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
                    dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::INT8x8x16"_hash);
                break;
            case StrategyType::INT8x8x32:
                cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
                    dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
                    "NoPackStrategyType::INT8x8x32"_hash);
                break;
            default:
                megdnn_throw(
                        ssprintf("Unsupported strategy type %u in no_pack mode",
                                 uint32_t(strategytype)));
                break;
        }
        megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode",
                              uint32_t(strategytype)));
    }

    //! Build a strategy for matmul algos using the ONLY_PACKA pack mode;
    //! currently only float32 NCHW is supported.
    static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType strategytype) {
        MEGDNN_MARK_USED_VAR(matmul_algo);
        switch (strategytype) {
            case StrategyType::FLOAT:
                cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32,
                    PostprocessMode::FLOAT,
                    "OnlyPackaStrategyType::FLOAT"_hash);
                break;
            default:
                megdnn_throw(ssprintf(
                        "Unsupported strategy type %u in onlypacka mode",
                        uint32_t(strategytype)));
                break;
        }
        megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode",
                              uint32_t(strategytype)));
    }

#undef cb1
#undef cb2

    //! Dispatch to the maker matching the matmul algo's pack mode.
    static std::unique_ptr<StrategyBase> make_strategy(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::MatrixMulImpl::AlgoBase::PackMode packmode,
            const fallback::ConvBiasImpl::NCBKernSizeParam& param,
            StrategyType stype) {
        switch (packmode) {
            case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
                return make_default_strategy(matmul_algo, param, stype);
                break;
            case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA:
                return make_onlypacka_strategy(matmul_algo, param, stype);
                break;
            case MatrixMulImpl::AlgoBase::PackMode::NO_PACK:
                return make_nopack_strategy(matmul_algo, param, stype);
                break;
            default:
                megdnn_throw(
                        "not support packmode except default onlypackA "
                        "nopack");
                break;
        }
        megdnn_throw("factory make Strategy error please check your code");
    }
};
  490. template <typename Strategy>
  491. Strategy* StrategyDelegationStorage::get(
  492. fallback::MatrixMulImpl::AlgoBase* matmul_algo,
  493. const fallback::ConvBiasImpl::NCBKernSizeParam& param,
  494. StrategyType stype) {
  495. fallback::MatrixMulImpl::AlgoBase::PackMode packmode =
  496. matmul_algo->packmode();
  497. //! nopack mode block_m block_n block_k is zero
  498. size_t block_m = 0, block_n = 0, block_k = 0;
  499. if (packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
  500. packmode == fallback::MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
  501. block_m = matmul_algo->get_inner_block_size().m;
  502. block_n = matmul_algo->get_inner_block_size().n;
  503. block_k = matmul_algo->get_inner_block_size().k;
  504. }
  505. StrategyHashParam sparam;
  506. sparam.param = param;
  507. sparam.format = param.filter_meta.format;
  508. sparam.packmode = packmode;
  509. sparam.block_m = block_m;
  510. sparam.block_n = block_n;
  511. sparam.block_k = block_k;
  512. sparam.kernel = param.filter_meta.spatial[0];
  513. sparam.stride = param.filter_meta.stride[0];
  514. sparam.is_square =
  515. param.filter_meta.spatial[0] == param.filter_meta.spatial[1];
  516. sparam.is_xcorr = param.filter_meta.should_flip;
  517. MEGDNN_LOCK_GUARD(m_mtx);
  518. if (map_strategys.find(sparam) == map_strategys.end()) {
  519. auto strategy =
  520. Factory::make_strategy(matmul_algo, packmode, param, stype);
  521. map_strategys[sparam] = std::move(strategy);
  522. }
  523. return static_cast<Strategy*>(map_strategys[sparam].get());
  524. }
  525. } // namespace im2col
  526. } // namespace fallback
  527. } // namespace megdnn
  528. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台