You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algos.cpp 51 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991
  1. /**
  2. * \file dnn/src/fallback/conv_bias/im2col/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/fallback/conv_bias/im2col/algos.h"
  12. #include "megdnn/opr_param_defs.h"
  13. #include "src/common/opr_delegate.h"
  14. #include "src/fallback/conv_bias/common.h"
  15. #include "src/fallback/conv_bias/opr_impl.h"
  16. #include "src/fallback/conv_bias/winograd/strategy.h"
  17. #include "src/fallback/convolution/img2col_helper.h"
  18. #include "src/naive/convolution/helper.h"
  19. #if MEGDNN_X86
  20. #include "src/x86/conv_bias/postprocess_helper.h"
  21. #endif
  22. #include "midout.h"
  23. MIDOUT_DECL(megdnn_fallback_im2col)
  24. using namespace megdnn;
  25. using namespace fallback;
  26. #if MEGDNN_X86
  27. using namespace x86;
  28. #endif
  29. /*======================== AlgoIm2col=======================*/
  30. /*!
  31. * *\brief The index of all parts workspace in im2col workspace bundel
  32. * *Through witch can convenient get the needed ptr
  33. */
struct Im2colBundelIndex {
    //! Slots of the top-level bundle returned by AlgoIm2col::get_bundle():
    //! copy-padding buffer, packed-A panels, and the block holding all
    //! per-thread sub-bundles.
    static constexpr size_t BUNDLE_PADDING_INDEX = 0_z;
    static constexpr size_t BUNDLE_PACKA_INDEX = 1_z;
    static constexpr size_t BUNDLE_THREAD_INDEX = 2_z;
    //! Slots of the per-thread bundle returned by get_thread_bundle():
    //! packed-B panel, im2col matrix, matmul destination, bias copy and the
    //! matmul algorithm's own compute workspace (same order as the
    //! initializer list in get_thread_bundle()).
    static constexpr size_t THREAD_BUNDLE_PACKB_INDEX = 0_z;
    static constexpr size_t THREAD_BUNDLE_IM2COL_INDEX = 1_z;
    static constexpr size_t THREAD_BUNDLE_MATMUL_DST_INDEX = 2_z;
    static constexpr size_t THREAD_BUNDLE_BIAS_INDEX = 3_z;
    static constexpr size_t THREAD_BUNDLE_COMPUTE_INDEX = 4_z;
};
  44. /*!
  45. * *\brief PtrGetter is get the im2col needed ptr according to the provided
  46. * *conditions
  47. */
  48. class PtrGetter {
  49. public:
  50. template <typename dtype>
  51. static inline dtype* get_matmul_dst_ptr(
  52. const ConvBiasImpl::NCBKernParam& param,
  53. const WorkspaceBundle& bundle_thread, size_t bundle_id,
  54. size_t oc_cur_index, size_t OHW, bool is_dst_8bit,
  55. bool ohw_bigger_ohwblock, size_t batch_id, size_t group_id) {
  56. if (is_dst_8bit || !ohw_bigger_ohwblock) {
  57. return static_cast<dtype*>(bundle_thread.get(bundle_id));
  58. } else {
  59. dtype* dst =
  60. param.dst<dtype>(batch_id, group_id) + oc_cur_index * OHW;
  61. return static_cast<dtype*>(dst);
  62. }
  63. }
  64. template <typename bias_ctype>
  65. static inline bias_ctype* get_bias_temp_ptr(
  66. const ConvBiasImpl::NCBKernParam& param,
  67. const WorkspaceBundle& bundle_thread) {
  68. bias_ctype* bias_tmp_ptr =
  69. param.bias_mode == megdnn::BiasMode::BIAS
  70. ? static_cast<bias_ctype*>(bundle_thread.get(
  71. Im2colBundelIndex::THREAD_BUNDLE_BIAS_INDEX))
  72. : nullptr;
  73. return bias_tmp_ptr;
  74. }
  75. template <typename dtype>
  76. static inline dtype* get_bundle_offset_byte_ptr(
  77. const WorkspaceBundle& bundle, size_t bundle_id, size_t offset) {
  78. return reinterpret_cast<dtype*>(
  79. reinterpret_cast<uintptr_t>(bundle.get(bundle_id)) + offset);
  80. }
  81. };
//! Shorthand for the matmul algo's pack mode (DEFAULT / ONLY_PACKA / NO_PACK);
//! each mode has its own Im2colKerns specialization below.
using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode;
  83. //! Process one input channel copy padding
  84. template <typename src_ctype>
  85. static void copy_padding_kern(WorkspaceBundle bundle,
  86. const ConvBiasImpl::NCBKernParam& param,
  87. ConvBiasImpl::NCBKernIndex ncb_index) {
  88. UNPACK_CONV_F32_NCB_KERN_SIZES(param);
  89. MEGDNN_MARK_USED_VAR(N);
  90. MEGDNN_MARK_USED_VAR(OC);
  91. MEGDNN_MARK_USED_VAR(OH);
  92. MEGDNN_MARK_USED_VAR(OW);
  93. MEGDNN_MARK_USED_VAR(FH);
  94. MEGDNN_MARK_USED_VAR(FW);
  95. MEGDNN_MARK_USED_VAR(SH);
  96. MEGDNN_MARK_USED_VAR(SW);
  97. size_t IW2 = IW + 2 * PW;
  98. size_t IH2 = IH + 2 * PH;
  99. size_t group_id = ncb_index.ndrange_id[0];
  100. size_t batch_id = ncb_index.ndrange_id[1];
  101. size_t channel_id = ncb_index.ndrange_id[2];
  102. size_t padding_group_size = IH2 * IW2 * IC;
  103. size_t workspace_channel_offset = IH2 * IW2 * channel_id;
  104. size_t workspace_group_offset = group_id * padding_group_size;
  105. size_t workspace_batch_offset =
  106. param.filter_meta.group * batch_id * padding_group_size;
  107. bundle.set(param.workspace_ptr);
  108. src_ctype src_zp = static_cast<src_ctype>(0);
  109. if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
  110. src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
  111. }
  112. src_ctype* src = const_cast<src_ctype*>(
  113. param.src<src_ctype>(batch_id, group_id, channel_id));
  114. src_ctype* src2;
  115. src2 = static_cast<src_ctype*>(
  116. bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) +
  117. workspace_group_offset + workspace_batch_offset +
  118. workspace_channel_offset;
  119. src_ctype* src2_ptr = src2;
  120. const src_ctype* src_ptr = src;
  121. if (PH != 0) {
  122. std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
  123. src2_ptr += PH * IW2;
  124. }
  125. rep(ih, IH) {
  126. if (PW != 0)
  127. rep(pw, PW) * (src2_ptr++) = src_zp;
  128. std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
  129. src2_ptr += IW;
  130. src_ptr += IW;
  131. if (PW != 0)
  132. rep(pw, PW) * (src2_ptr++) = src_zp;
  133. }
  134. if (PH != 0) {
  135. std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
  136. src2_ptr += PH * IW2;
  137. }
  138. };
  139. /*!
  140. * *\brief Im2colKerns collects all the im2col kerns in it
  141. */
//! COPY_BIAS: if the bias mode is BIAS (a full per-element bias tensor), copy
//! the bias values covering the current (oc, ohw) tile into the per-thread
//! bias workspace so postprocess can read them contiguously; in every other
//! mode only `bias_ptr` is set up.  Expects `param`, `bundle_thread`,
//! `batch_id`, `group_id`, `oc_cur_index`, `oc_end_index`, `ohw_cur_index`,
//! `output_block_size`, `OH`, `OW` and the `bias_ctype` type parameter to be
//! visible at the expansion site.
#define COPY_BIAS()                                                           \
    const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(              \
            param.bias<bias_ctype>(batch_id, group_id));                      \
    bias_ctype* bias_temp_ptr =                                               \
            PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread);   \
    if (param.bias_mode == megdnn::BiasMode::BIAS) {                          \
        bias_ctype* copy_dst = bias_temp_ptr;                                 \
        const bias_ctype* copy_src =                                          \
                bias_ptr + oc_cur_index * OH * OW + ohw_cur_index;            \
        for (size_t oc = oc_cur_index; oc < oc_end_index; oc++) {             \
            std::memcpy(copy_dst, copy_src,                                   \
                        sizeof(bias_ctype) * output_block_size);              \
            copy_dst += output_block_size;                                    \
            copy_src += OH * OW;                                              \
        }                                                                     \
    }
//! IM2COL: build the B matrix for the current ohw tile.  For the special
//! 1x1/stride-1/no-padding case nothing is materialized (`im2col_dst` stays
//! nullptr and `no_padding_src` is used directly as B); otherwise the
//! img2col/img2col_stride helpers expand the (padded) source into the
//! per-thread im2col workspace.  Expects `param`, `bundle`, `bundle_thread`,
//! `ncb_index`, `batch_id`, `group_id`, `special_1x1`, `is_xcorr`,
//! `ohw_cur_index`, `output_block_size` and the conv geometry variables
//! (IC/IH2/IW2/OC/OH/OW/FH/FW/SH/SW/PH/PW) in scope.
#define IM2COL()                                                               \
    src_ctype* im2col_dst = nullptr;                                           \
    src_ctype* no_padding_src =                                                \
            const_cast<src_ctype*>(param.src<src_ctype>(batch_id, group_id)) + \
            ohw_cur_index;                                                     \
    if (!special_1x1) {                                                        \
        size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype);        \
        src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>(    \
                bundle, Im2colBundelIndex::BUNDLE_PADDING_INDEX,               \
                (ncb_index.ndrange_id[0] +                                     \
                 param.filter_meta.group * ncb_index.ndrange_id[1]) *          \
                        padding_group_size);                                   \
        if (PH == 0 && PW == 0) {                                              \
            src2 = const_cast<src_ctype*>(                                     \
                    param.src<src_ctype>(batch_id, group_id));                 \
        }                                                                      \
        im2col_dst = static_cast<src_ctype*>(bundle_thread.get(                \
                Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX));               \
        if (SH == 1 && SW == 1) {                                              \
            if (is_xcorr) {                                                    \
                img2col<true>(src2, im2col_dst, OC, OH, OW, IC, IH2, IW2, FH,  \
                              FW, ohw_cur_index, output_block_size);           \
            } else {                                                           \
                img2col<false>(src2, im2col_dst, OC, OH, OW, IC, IH2, IW2, FH, \
                               FW, ohw_cur_index, output_block_size);          \
            }                                                                  \
        } else {                                                               \
            if (is_xcorr) {                                                    \
                img2col_stride<true>(src2, im2col_dst, OC, OH, OW, IC, IH2,    \
                                     IW2, FH, FW, SH, SW, ohw_cur_index,       \
                                     output_block_size);                       \
            } else {                                                           \
                img2col_stride<false>(src2, im2col_dst, OC, OH, OW, IC, IH2,   \
                                      IW2, FH, FW, SH, SW, ohw_cur_index,      \
                                      output_block_size);                      \
            }                                                                  \
        }                                                                      \
    }
//! POSTPROCESS_AND_COPYDST: apply bias + nonlinearity in place on the matmul
//! result, then — unless the matmul already wrote directly into the output
//! tensor (`skip_copy_dst`) — copy the tile row by row into dst at the
//! current (oc, ohw) offset.  Expects `matmul_dst`, `bias_ptr`,
//! `bias_temp_ptr`, `skip_copy_dst`, `batch_id`, `group_id`, `oc_cur_index`,
//! `ohw_cur_index`, `output_block_oc_size`, `output_block_size` and `OHW`
//! in scope.
#define POSTPROCESS_AND_COPYDST()                                             \
    PostProcess<op_ctype, op_dtype, postprocess_mode>::run(                   \
            matmul_dst,                                                       \
            param.bias_mode == megdnn::BiasMode::BIAS                         \
                    ? bias_temp_ptr                                           \
                    : const_cast<bias_ctype*>(bias_ptr + oc_cur_index),       \
            matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type,  \
            param.dst_type, 1_z, output_block_oc_size, 1_z,                   \
            output_block_size);                                               \
    if (!skip_copy_dst) {                                                     \
        dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst);    \
        dst_ctype* dst = param.dst<dst_ctype>(batch_id, group_id) +           \
                         oc_cur_index * OHW + ohw_cur_index;                  \
        for (size_t oc = 0; oc < output_block_oc_size; oc++) {                \
            std::memcpy(dst, dst_tmp_ptr,                                     \
                        sizeof(dst_ctype) * output_block_size);               \
            dst_tmp_ptr += output_block_size;                                 \
            dst += OHW;                                                       \
        }                                                                     \
    }
//! PREPAR_MATMUL_DATA (DEFAULT pack mode): locate the pre-packed A panel for
//! the current (group, oc block) in the shared packA workspace, the
//! per-thread packed-B slot, and the matmul destination.  Redefined (and
//! #undef'd) per Im2colKerns specialization.
#define PREPAR_MATMUL_DATA()                                                  \
    size_t packA_per_oc_block_size =                                          \
            round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * \
            oc_tile_size * matmul_algo->get_packA_type_size();                \
    size_t packA_group_size =                                                 \
            matmul_algo->get_bundle(matmul_param).get_size(0);                \
    src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>(    \
            bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX,                    \
            ncb_index.ndrange_id[0] * packA_group_size +                      \
                    ncb_index.ndrange_id[3] * packA_per_oc_block_size);       \
    src_ctype* b_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>(    \
            bundle_thread, Im2colBundelIndex::THREAD_BUNDLE_PACKB_INDEX, 0);  \
    /*In pack mode, the matmul dst and im2col dst is the same workspace*/     \
    bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>(       \
            param, bundle_thread,                                             \
            Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \
            is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
//! MATMUL_COMPUTE (DEFAULT pack mode): fix up the kern param for the actual
//! tile sizes, pack B, then run the naked matmul kernel with the pre-packed
//! A and freshly packed B panels.  In the special 1x1 case B is the raw
//! source plane, hence LDB = OH * OW.
#define MATMUL_COMPUTE()                                                      \
    auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);       \
    matmul_param.M = output_block_oc_size;                                    \
    matmul_param.N = output_block_size;                                       \
    matmul_param.LDB = special_1x1 ? OH * OW : output_block_size;             \
    matmul_param.LDC = output_block_size;                                     \
    matmul_param.A_ptr = a_panel;                                             \
    matmul_param.B_ptr = im2col_dst ? im2col_dst : no_padding_src;            \
    matmul_param.C_ptr = matmul_dst;                                          \
    matmul_algo->pack_B(matmul_param, b_panel, 0, output_block_size);         \
    matmul_kern_naked(matmul_param, a_panel, b_panel);
//! Im2colKerns collects the im2col kernels for one matmul pack mode;
//! specialized below for DEFAULT, ONLY_PACKA and NO_PACK.
template <Pack_Mode packmode>
class Im2colKerns;
template <>
class Im2colKerns<Pack_Mode::DEFAULT> {
public:
    //! packA kern
    //! Packs one inner-m-block slice of this group's filter matrix into the
    //! shared packA workspace; parallelized over (group, oc inner block).
    template <typename src_ctype>
    static void packA_kern(WorkspaceBundle bundle,
                           const ConvBiasImpl::NCBKernParam& param,
                           fallback::MatrixMulImpl::KernSizeParam matmulparam,
                           fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                           ConvBiasImpl::NCBKernIndex ncb_index) {
        bundle.set(param.workspace_ptr);
        fallback::MatrixMulImpl::KernParam matmul_param;
        size_t group_id = ncb_index.ndrange_id[0];
        //! copy only the size fields; pointers are filled in below
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmulparam;
        size_t packA_group_size =
                matmul_algo->get_bundle(matmul_param).get_size(0);
        //! bytes of one packed m-block: K rounded up to the matmul's inner
        //! k-block, times the inner m-block, times the packed element size
        size_t packed_per_oc_block_size =
                round_up(matmul_param.K,
                         matmul_algo->get_inner_block_size().k) *
                matmul_algo->get_inner_block_size().m *
                matmul_algo->get_packA_type_size();
        size_t a_panel_offset =
                ncb_index.ndrange_id[2] * packed_per_oc_block_size;
        int8_t* a_panel = static_cast<int8_t*>(bundle.get(
                                  Im2colBundelIndex::BUNDLE_PACKA_INDEX)) +
                          group_id * packA_group_size + a_panel_offset;
        matmul_param.A_ptr =
                const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
        matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2],
                            matmul_algo->get_inner_block_size().m);
    };
    //! conv kernel
    //! One (group, batch, ohw tile, oc tile) unit of work: copy bias if
    //! needed, im2col, pack B + matmul against the pre-packed A panel, then
    //! postprocess and copy the tile to dst if it was not written directly.
    template <typename src_ctype, typename bias_ctype, typename dst_ctype,
              typename op_ctype, typename op_dtype,
              PostprocessMode postprocess_mode>
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, size_t oc_tile_size) {
        auto is_xcorr = !param.filter_meta.should_flip;
        UNPACK_CONV_F32_NCB_KERN_SIZES(param);
        MEGDNN_MARK_USED_VAR(N);
        auto IH2 = IH + 2 * PH;
        auto IW2 = IW + 2 * PW;
        size_t OHW = OH * OW;
        size_t group_id = ncb_index.ndrange_id[0];
        size_t batch_id = ncb_index.ndrange_id[1];
        //! tail tiles may be smaller than the nominal tile size
        size_t output_block_size = std::min(
                ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                oc_tile_size, OC - ncb_index.ndrange_id[3] * oc_tile_size);
        //! misc flags
        //! 1x1/stride-1/no-padding conv: B can be read straight from src
        bool special_1x1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 &&
                            PH == 0 && PW == 0);
        //! 8-bit dst needs a wider intermediate, so matmul cannot write dst
        bool is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        bool is_ohw_size_bigger = (ohw_tile_size >= OHW);
        bool skip_copy_dst = is_ohw_size_bigger && !is_dst_8bit;
        //! misc index
        size_t ohw_cur_index = ncb_index.ndrange_id[2] * ohw_tile_size;
        size_t oc_cur_index = ncb_index.ndrange_id[3] * oc_tile_size;
        size_t oc_end_index = oc_cur_index + output_block_oc_size;
        bundle.set(param.workspace_ptr);
        //! this thread's slice of the per-thread workspace block
        bundle_thread.set(PtrGetter::get_bundle_offset_byte_ptr<int8_t>(
                bundle, Im2colBundelIndex::BUNDLE_THREAD_INDEX,
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id));
        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;
        matmul_param.workspace_ptr = bundle_thread.get(
                Im2colBundelIndex::THREAD_BUNDLE_COMPUTE_INDEX);
        //! 1.Copy bias if need
        COPY_BIAS();
        //! 2.Im2col
        IM2COL();
        //! 3.packb and matmul compute
        PREPAR_MATMUL_DATA();
        MATMUL_COMPUTE();
        //! 4.postprocess and copy dst if need
        POSTPROCESS_AND_COPYDST();
#undef PREPAR_MATMUL_DATA
#undef MATMUL_COMPUTE
    }
};
//! PREPAR_MATMUL_DATA (ONLY_PACKA mode): locate the pre-packed A panel for
//! the current (group, oc tile) and the matmul destination; B is used
//! unpacked in this mode so b_panel stays nullptr.
#define PREPAR_MATMUL_DATA()                                                   \
    bias_ctype* matmul_dst = nullptr;                                          \
    src_ctype* b_panel = nullptr;                                              \
    size_t packA_group_size =                                                  \
            bundle.get_size(Im2colBundelIndex::BUNDLE_PACKA_INDEX) /           \
            param.filter_meta.group;                                           \
    size_t a_panel_offset = ncb_index.ndrange_id[3] *                          \
                            matmul_algo->get_bundle(matmul_param).get_size(0); \
                                                                               \
    src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>(     \
            bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX,                     \
            group_id * packA_group_size + a_panel_offset);                     \
    matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>(                    \
            param, bundle_thread,                                              \
            Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index,   \
            OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
//! MATMUL_COMPUTE (ONLY_PACKA mode): same as the DEFAULT variant but without
//! the pack_B step — the naked kernel consumes B (im2col output or raw src
//! for the 1x1 case) directly.
#define MATMUL_COMPUTE()                                                      \
    auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);       \
    matmul_param.M = output_block_oc_size;                                    \
    matmul_param.N = output_block_size;                                       \
    matmul_param.LDB = special_1x1 ? OH * OW : output_block_size;             \
    matmul_param.LDC = output_block_size;                                     \
    matmul_param.A_ptr = a_panel;                                             \
    matmul_param.B_ptr = im2col_dst ? im2col_dst : no_padding_src;            \
    matmul_param.C_ptr = matmul_dst;                                          \
    matmul_kern_naked(matmul_param, a_panel, b_panel);
template <>
class Im2colKerns<Pack_Mode::ONLY_PACKA> {
public:
    //! packA kern
    //! Packs one oc tile of this group's filter matrix; unlike DEFAULT mode
    //! the packing granularity is the conv oc tile (matmul_param.M), not the
    //! matmul inner m-block.
    template <typename src_ctype>
    static void packA_kern(WorkspaceBundle bundle,
                           const ConvBiasImpl::NCBKernParam& param,
                           fallback::MatrixMulImpl::KernSizeParam matmulparam,
                           fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                           ConvBiasImpl::NCBKernIndex ncb_index) {
        bundle.set(param.workspace_ptr);
        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmulparam;
        size_t OC = param.filter_meta.ocpg;
        //! the kern-size param was built with M = oc tile size
        size_t oc_tile_size = matmul_param.M;
        size_t group_id = ncb_index.ndrange_id[0];
        //! the last tile may cover fewer output channels
        size_t output_block_oc_size = std::min(
                oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size);
        size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size;
        size_t packA_group_size =
                bundle.get_size(Im2colBundelIndex::BUNDLE_PACKA_INDEX) /
                param.filter_meta.group;
        size_t a_panel_offset =
                ncb_index.ndrange_id[2] *
                matmul_algo->get_bundle(matmul_param).get_size(0);
        int8_t* a_panel = static_cast<int8_t*>(bundle.get(
                                  Im2colBundelIndex::BUNDLE_PACKA_INDEX)) +
                          group_id * packA_group_size + a_panel_offset;
        matmul_param.A_ptr =
                const_cast<src_ctype*>(param.filter<src_ctype>(group_id)) +
                oc_cur_index * matmul_param.K;
        matmul_param.M = output_block_oc_size;
        matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z);
    };
    //! conv kernel
    //! One (group, batch, ohw tile, oc tile) unit of work: copy bias if
    //! needed, im2col, matmul against the pre-packed A (B unpacked), then
    //! postprocess and copy the tile to dst if it was not written directly.
    template <typename src_ctype, typename bias_ctype, typename dst_ctype,
              typename op_ctype, typename op_dtype,
              PostprocessMode postprocess_mode>
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, size_t oc_tile_size) {
        auto is_xcorr = !param.filter_meta.should_flip;
        UNPACK_CONV_F32_NCB_KERN_SIZES(param);
        MEGDNN_MARK_USED_VAR(N);
        auto IH2 = IH + 2 * PH;
        auto IW2 = IW + 2 * PW;
        size_t group_id = ncb_index.ndrange_id[0];
        size_t batch_id = ncb_index.ndrange_id[1];
        size_t OHW = OH * OW;
        //! tail tiles may be smaller than the nominal tile size
        size_t output_block_size = std::min(
                ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                oc_tile_size, OC - ncb_index.ndrange_id[3] * oc_tile_size);
        //! misc flags
        //! 1x1/stride-1/no-padding conv: B can be read straight from src
        bool special_1x1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 &&
                            PH == 0 && PW == 0);
        bool is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        bool is_ohw_size_bigger = (ohw_tile_size >= OHW);
        bool skip_copy_dst = is_ohw_size_bigger && !is_dst_8bit;
        //! misc index
        size_t ohw_cur_index = ncb_index.ndrange_id[2] * ohw_tile_size;
        size_t oc_cur_index = ncb_index.ndrange_id[3] * oc_tile_size;
        size_t oc_end_index = oc_cur_index + output_block_oc_size;
        bundle.set(param.workspace_ptr);
        //! this thread's slice of the per-thread workspace block
        bundle_thread.set(PtrGetter::get_bundle_offset_byte_ptr<int8_t>(
                bundle, Im2colBundelIndex::BUNDLE_THREAD_INDEX,
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id));
        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;
        matmul_param.workspace_ptr = bundle_thread.get(
                Im2colBundelIndex::THREAD_BUNDLE_COMPUTE_INDEX);
        //! 1.Copy bias if need
        COPY_BIAS();
        //! 2.Im2col
        IM2COL();
        //! 3.packb and matmul compute
        PREPAR_MATMUL_DATA();
        MATMUL_COMPUTE();
        //! 4.postprocess and copy dst if need
        POSTPROCESS_AND_COPYDST();
#undef PREPAR_MATMUL_DATA
#undef MATMUL_COMPUTE
    }
};
//! PREPAR_MATMUL_DATA (NO_PACK mode): A is the raw filter at the current oc
//! offset (no packing at all); only the matmul destination is resolved.
#define PREPAR_MATMUL_DATA()                                                  \
    bias_ctype* matmul_dst = nullptr;                                         \
    const src_ctype* filter =                                                 \
            param.filter<src_ctype>(group_id) + oc_cur_index * IC * FH * FW;  \
    matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>(                   \
            param, bundle_thread,                                             \
            Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index,  \
            OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id);
//! MATMUL_COMPUTE (NO_PACK mode): run the plain (non-naked) matmul kernel on
//! the raw filter and the im2col output (or raw src in the 1x1 case).
#define MATMUL_COMPUTE()                                                      \
    matmul_param.M = output_block_oc_size;                                    \
    matmul_param.N = output_block_size;                                       \
    matmul_param.LDB = special_1x1 ? OH * OW : output_block_size;             \
    matmul_param.LDC = output_block_size;                                     \
    matmul_param.A_ptr = filter;                                              \
    matmul_param.B_ptr = im2col_dst ? im2col_dst : no_padding_src;            \
    matmul_param.C_ptr = matmul_dst;                                          \
    auto matmul_kern_t = matmul_algo->get_kern(matmul_param);                 \
    matmul_kern_t(matmul_param);
template <>
class Im2colKerns<Pack_Mode::NO_PACK> {
public:
    //! conv kernel
    //! One (group, batch, ohw tile, oc tile) unit of work for matmul algos
    //! without packing: copy bias if needed, im2col, plain matmul on the raw
    //! filter, then postprocess and copy the tile to dst if it was not
    //! written directly.  There is no packA_kern in this mode.
    template <typename src_ctype, typename bias_ctype, typename dst_ctype,
              typename op_ctype, typename op_dtype,
              PostprocessMode postprocess_mode>
    static void kerns(
            WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
            const ConvBiasImpl::NCBKernParam& param,
            fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
            fallback::ConvBiasImpl::NCBKernIndex ncb_index,
            size_t ohw_tile_size, size_t oc_tile_size) {
        auto is_xcorr = !param.filter_meta.should_flip;
        UNPACK_CONV_F32_NCB_KERN_SIZES(param);
        MEGDNN_MARK_USED_VAR(N);
        auto IH2 = IH + 2 * PH;
        auto IW2 = IW + 2 * PW;
        size_t group_id = ncb_index.ndrange_id[0];
        size_t batch_id = ncb_index.ndrange_id[1];
        size_t OHW = OH * OW;
        //! tail tiles may be smaller than the nominal tile size
        size_t output_block_size = std::min(
                ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size);
        size_t output_block_oc_size = std::min(
                oc_tile_size, OC - ncb_index.ndrange_id[3] * oc_tile_size);
        //! misc flags
        //! 1x1/stride-1/no-padding conv: B can be read straight from src
        bool special_1x1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 &&
                            PH == 0 && PW == 0);
        bool is_dst_8bit =
                (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
        bool is_ohw_size_bigger = (ohw_tile_size >= OHW);
        bool skip_copy_dst = is_ohw_size_bigger && !is_dst_8bit;
        //! misc index
        size_t ohw_cur_index = ncb_index.ndrange_id[2] * ohw_tile_size;
        size_t oc_cur_index = ncb_index.ndrange_id[3] * oc_tile_size;
        size_t oc_end_index = oc_cur_index + output_block_oc_size;
        bundle.set(param.workspace_ptr);
        //! this thread's slice of the per-thread workspace block
        bundle_thread.set(PtrGetter::get_bundle_offset_byte_ptr<int8_t>(
                bundle, Im2colBundelIndex::BUNDLE_THREAD_INDEX,
                bundle_thread.total_size_in_bytes() * ncb_index.thread_id));
        fallback::MatrixMulImpl::KernParam matmul_param;
        static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
                matmul_kernsize_param;
        matmul_param.workspace_ptr = bundle_thread.get(
                Im2colBundelIndex::THREAD_BUNDLE_COMPUTE_INDEX);
        //! 1.Copy bias if need
        COPY_BIAS();
        //! 2.Im2col
        IM2COL();
        //! 3.packb and matmul compute
        PREPAR_MATMUL_DATA();
        MATMUL_COMPUTE();
        //! 4.postprocess and copy dst if need
        POSTPROCESS_AND_COPYDST();
#undef PREPAR_MATMUL_DATA
#undef MATMUL_COMPUTE
    }
};
  537. #undef COPY_BIAS
  538. #undef IM2COL
  539. #undef POSTPROCESS_AND_COPYDST
  540. fallback::MatrixMulImpl::KernSizeParam
  541. ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
  542. size_t ohw_tile_size,
  543. size_t oc_tile_size) const {
  544. size_t M = oc_tile_size;
  545. size_t N = ohw_tile_size;
  546. size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] *
  547. param.filter_meta.spatial[1];
  548. size_t LDA = K, LDB = N, LDC = N;
  549. bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
  550. param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
  551. (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
  552. param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
  553. return {param.filter_type,
  554. param.src_type,
  555. is_dst_8bit ? param.bias_type : param.dst_type,
  556. M,
  557. N,
  558. K,
  559. LDA,
  560. LDB,
  561. LDC,
  562. false,
  563. false,
  564. param::MatrixMul::ComputeMode::DEFAULT,
  565. param::MatrixMul::Format::DEFAULT};
  566. }
//! Choose m_ohw_tile_size / m_oc_tile_size for this problem size and thread
//! count.  block_m / block_n are the matmul's granularity so the chosen
//! tiles stay aligned with it.  (NOTE(review): mutates members from a const
//! method — presumably they are declared mutable; confirm in the header.)
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
        const NCBKernSizeParam& param, size_t block_m, size_t block_n,
        bool need_pack) const {
    size_t nr_threads = param.nr_threads;
    size_t OC = param.filter_meta.ocpg;
    size_t ohw = param.osz[0] * param.osz[1];
    //! Do not remove or reorder the next two lines: the same algo object is
    //! reused across oprs, so the tile sizes must be re-initialized from
    //! their original values on every call.  Otherwise a previous call's
    //! result would leak in, the computed workspace size would differ
    //! between get_workspace() and exec(), and a workspace-mismatch error
    //! would occur.
    m_oc_tile_size = DEFAULT_OC_TILE_SIZE;
    m_ohw_tile_size = m_ohw_tile_origin;
    //! never tile beyond the actual problem size
    m_oc_tile_size = std::min(m_oc_tile_size, OC);
    m_ohw_tile_size = std::min(m_ohw_tile_size, ohw);
    if (nr_threads > 1) {
        //! too few ohw tiles to keep all threads busy: shrink the ohw tile,
        //! and if it would fall below the minimum, parallelize over OC
        //! instead (whole-ohw tiles, OC split across threads)
        if (ohw / m_ohw_tile_size < nr_threads) {
            m_ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
            if (m_ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
                m_ohw_tile_size = ohw;
                m_oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
                //! clamp the oc tile into [MIN, MAX]
                if (m_oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
                    m_oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
                } else if (m_oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
                    m_oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
                }
            }
        }
    } else {
        if (!need_pack) { //! no pack, usually on x86 — saves memory by
                          //! processing the whole output in one tile
            m_ohw_tile_size = ohw;
            m_oc_tile_size = OC;
        }
    }
}
  602. WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
  603. const NCBKernSizeParam& param) const {
  604. UNPACK_CONV_F32_NCB_KERN_SIZES(param);
  605. MEGDNN_MARK_USED_VAR(OC);
  606. MEGDNN_MARK_USED_VAR(OH);
  607. MEGDNN_MARK_USED_VAR(OW);
  608. MEGDNN_MARK_USED_VAR(FH);
  609. MEGDNN_MARK_USED_VAR(FW);
  610. MEGDNN_MARK_USED_VAR(SW);
  611. MEGDNN_MARK_USED_VAR(SH);
  612. auto IW2 = IH + 2 * PH;
  613. auto IH2 = IW + 2 * PW;
  614. bool no_need_pading = (PH == 0 && PW == 0);
  615. size_t padding = 0, packa_size = 0, packa_group_size = 0;
  616. size_t nr_threads = param.nr_threads;
  617. size_t GROUP = param.filter_meta.group;
  618. bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
  619. bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
  620. if (need_pack || only_packA) {
  621. auto inner_block = m_matmul_algo->get_inner_block_size();
  622. choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack);
  623. auto im2col_kern_param = get_matmul_kern_param(
  624. param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC);
  625. size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
  626. WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param);
  627. packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
  628. : wb.get_size(0);
  629. } else { //! not support pack,not need pack
  630. size_t nopack_default_blockm = 8;
  631. size_t nopack_default_blockn = 16;
  632. choice_ohw_oc_block(param, nopack_default_blockm, nopack_default_blockn,
  633. need_pack);
  634. packa_group_size = 0;
  635. }
  636. if (no_need_pading) {
  637. padding = 0; //! not need padding
  638. } else {
  639. padding = (GROUP * N * IC * IH2 * IW2) *
  640. sizeof(param.src_type); //! for padding
  641. }
  642. packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
  643. WorkspaceBundle ws = get_thread_bundle(param);
  644. return {nullptr,
  645. {padding, packa_size, ws.total_size_in_bytes() * nr_threads}};
  646. }
  647. WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_thread_bundle(
  648. const NCBKernSizeParam& param) const {
  649. size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
  650. FW = param.filter_meta.spatial[1];
  651. size_t ohw = param.osz[0] * param.osz[1];
  652. size_t im2col = 0, packb = 0, matmul_dst = 0, bias_temp = 0,
  653. matmul_compute = 0;
  654. auto im2col_kern_param =
  655. get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
  656. bool default_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
  657. bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
  658. bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
  659. param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
  660. (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
  661. param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
  662. size_t im2col_dst_size =
  663. IC * FH * FW * m_ohw_tile_size * sizeof(param.src_type);
  664. size_t matmul_dst_size =
  665. m_oc_tile_size * m_ohw_tile_size * sizeof(param.bias_type);
  666. if (default_pack || only_packA) {
  667. //! matmul_dst and im2col_dst use the same memory
  668. WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param);
  669. packb = wb.get_size(1);
  670. im2col = only_packA ? im2col_dst_size
  671. : std::max(im2col_dst_size, matmul_dst_size);
  672. matmul_dst = only_packA ? matmul_dst_size : 0;
  673. } else {
  674. im2col = im2col_dst_size;
  675. if (is_dst_8bit) {
  676. matmul_dst = matmul_dst_size;
  677. } else {
  678. matmul_dst = m_ohw_tile_size >= ohw ? 0 : matmul_dst_size;
  679. }
  680. matmul_compute = m_matmul_algo->get_workspace(im2col_kern_param);
  681. }
  682. if (param.bias_mode == megdnn::BiasMode::BIAS) {
  683. bias_temp = m_oc_tile_size * m_ohw_tile_size * sizeof(param.bias_type);
  684. }
  685. return {nullptr, {packb, im2col, matmul_dst, bias_temp, matmul_compute}};
  686. }
  687. size_t ConvBiasImpl::AlgoIm2col::get_workspace(
  688. ConvBiasImpl*, const NCBKernSizeParam& p) const {
  689. MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 0) {
  690. return get_bundle(p).total_size_in_bytes();
  691. }
  692. MIDOUT_END();
  693. return 0;
  694. }
//! Build the list of NCB kernels that together execute one im2col conv-bias:
//! optional packA kernel(s), optional padding-copy kernel, then the compute
//! kernel matching the matmul algo's pack mode. The dtype dispatch is done
//! through the cb(...) macros below; each match returns its kernel list
//! directly from inside the macro via RETURN_KERNS().
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
        ConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) {
        size_t ohw = param.osz[0] * param.osz[1];
        size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size);
        size_t GROUP = param.filter_meta.group;
        size_t IC = param.filter_meta.icpg;
        size_t OC = param.filter_meta.ocpg;
        size_t PH = param.filter_meta.padding[0];
        size_t PW = param.filter_meta.padding[1];
        WorkspaceBundle bundle = get_bundle(param);
        WorkspaceBundle bundle_thread = get_thread_bundle(param);
        size_t oc_parallel_times = div_ceil(OC, m_oc_tile_size);
        bool need_padding = (PH != 0 || PW != 0);
        bool default_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
        bool no_pack = m_matmul_algo->packmode() == Pack_Mode::NO_PACK;
        bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
        //! packA parallelism: per oc tile for ONLY_PACKA, per matmul inner
        //! block for DEFAULT, none for NO_PACK
        size_t packa_parallel_times = 0;
        if (only_packA) {
            packa_parallel_times = div_ceil(OC, m_oc_tile_size);
        } else if (default_pack) {
            packa_parallel_times =
                    div_ceil(OC, m_matmul_algo->get_inner_block_size().m);
        }
        auto matmul_param = get_matmul_kern_param(
                param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC);
        SmallVector<ConvBiasImpl::NCBKern> ret_kern;
//! Assemble and return the kernel list for the selected pack mode. Each
//! kernel entry pairs a lambda with its parallelism shape
//! {GROUP, batch, ...}; packA runs once per group (1_z batch dim).
#define RETURN_KERNS()                                                       \
    if (default_pack) {                                                      \
        ret_kern.push_back(                                                  \
                {kern_default_packA, {GROUP, 1_z, packa_parallel_times}});   \
    }                                                                        \
    if (only_packA) {                                                        \
        ret_kern.push_back(                                                  \
                {kern_only_packA, {GROUP, 1_z, packa_parallel_times}});      \
    }                                                                        \
    if (need_padding) {                                                      \
        ret_kern.push_back({kern_padding, {GROUP, param.n, IC}});            \
    }                                                                        \
    if (default_pack) {                                                      \
        ret_kern.push_back(                                                  \
                {kern_compute_default,                                       \
                 {GROUP, param.n, ohw_parallel_times, oc_parallel_times}});  \
    }                                                                        \
    if (no_pack) {                                                           \
        ret_kern.push_back(                                                  \
                {kern_compute_nopack,                                        \
                 {GROUP, param.n, ohw_parallel_times, oc_parallel_times}});  \
    }                                                                        \
    if (only_packA) {                                                        \
        ret_kern.push_back(                                                  \
                {kern_compute_onlypackA,                                     \
                 {GROUP, param.n, ohw_parallel_times, oc_parallel_times}});  \
    }                                                                        \
    return ret_kern;
//! Variant for dtypes where src/filter/dst share one ctype (_dt): defines
//! the compute lambda kern_compute_<_name> for the given pack mode.
#define COMPUTE_KERN(_name, _pack_mode, _dt, _post_ctype, _postprocess_mode) \
    auto kern_compute_##_name = [bundle, bundle_thread, matmul_param,        \
                                 matmul_algo = m_matmul_algo,                \
                                 ohw_tile_size = m_ohw_tile_size,            \
                                 oc_tile_size = m_oc_tile_size](             \
                                        const NCBKernParam& param,           \
                                        const NCBKernIndex& ncb_index) {     \
        Im2colKerns<_pack_mode>::kerns<_dt, _dt, _dt, _post_ctype,           \
                                       _post_ctype, _postprocess_mode>(      \
                bundle, bundle_thread, param, matmul_param, matmul_algo,     \
                ncb_index, ohw_tile_size, oc_tile_size);                     \
    };
//! Dispatch for the single-ctype dtypes (float32/float16): if the filter
//! dtype matches, build padding/packA/compute lambdas and return the kernel
//! list; otherwise fall through to the next cb().
#define cb(_dt, _post_ctype, _postprocess_mode, _midout_tags)                 \
    do {                                                                      \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {            \
            MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1, _midout_tags) {        \
                auto kern_padding = [bundle](const NCBKernParam& param,       \
                                             const NCBKernIndex& ncb_index) { \
                    copy_padding_kern<_dt>(bundle, param, ncb_index);         \
                };                                                            \
                auto kern_default_packA =                                     \
                        [bundle, matmul_algo = m_matmul_algo, matmul_param](  \
                                const NCBKernParam& param,                    \
                                const NCBKernIndex& ncb_index) {              \
                            Im2colKerns<Pack_Mode::DEFAULT>::packA_kern<_dt>( \
                                    bundle, param, matmul_param, matmul_algo, \
                                    ncb_index);                               \
                        };                                                    \
                auto kern_only_packA = [bundle, matmul_algo = m_matmul_algo,  \
                                        matmul_param](                        \
                                               const NCBKernParam& param,     \
                                               const NCBKernIndex&           \
                                                       ncb_index) {           \
                    Im2colKerns<Pack_Mode::ONLY_PACKA>::packA_kern<_dt>(      \
                            bundle, param, matmul_param, matmul_algo,         \
                            ncb_index);                                       \
                };                                                            \
                COMPUTE_KERN(default, Pack_Mode::DEFAULT, _dt, _post_ctype,   \
                             _postprocess_mode);                              \
                COMPUTE_KERN(nopack, Pack_Mode::NO_PACK, _dt, _post_ctype,    \
                             _postprocess_mode);                              \
                COMPUTE_KERN(onlypackA, Pack_Mode::ONLY_PACKA, _dt,           \
                             _post_ctype, _postprocess_mode);                 \
                RETURN_KERNS();                                               \
            }                                                                 \
            MIDOUT_END();                                                     \
            return {};                                                        \
        }                                                                     \
    } while (0);
        cb(dt_float32, dt_float32, PostprocessMode::FLOAT, 0);
#if !MEGDNN_DISABLE_FLOAT16
        cb(dt_float16, dt_float16, PostprocessMode::NO_PROCESS, 2);
#endif
#undef cb
#undef COMPUTE_KERN
//! Variant for mixed-dtype paths (e.g. int8 src -> int32 dst): src/bias/dst
//! each have their own ctype, and the intermediate bias/dst dtypes are
//! passed separately so their ctypes can be derived via DTypeTrait.
#define COMPUTE_KERN(_name, _pack_mode, _src_ctype, _bias_ctype, _dst_ctype, \
                     _i_bias_type, _i_dst_type, _postprocess_mode)           \
    auto kern_compute_##_name = [bundle, bundle_thread, matmul_param,        \
                                 matmul_algo = m_matmul_algo,                \
                                 ohw_tile_size = m_ohw_tile_size,            \
                                 oc_tile_size = m_oc_tile_size](             \
                                        const NCBKernParam& param,           \
                                        const NCBKernIndex& ncb_index) {     \
        Im2colKerns<_pack_mode>::kerns<_src_ctype, _bias_ctype, _dst_ctype,  \
                                       DTypeTrait<_i_bias_type>::ctype,      \
                                       DTypeTrait<_i_dst_type>::ctype,       \
                                       _postprocess_mode>(                   \
                bundle, bundle_thread, param, matmul_param, matmul_algo,     \
                ncb_index, ohw_tile_size, oc_tile_size);                     \
    };
//! Dispatch for the mixed-dtype paths: matches on src==filter dtype plus the
//! exact src/dst dtype pair before building the kernel list.
#define cb(_i_src_type, _i_bias_type, _i_dst_type, _src_ctype, _bias_ctype,   \
           _dst_ctype, _postprocess_mode, _midout_tags)                       \
    do {                                                                      \
        if (param.filter_type.enumv() == param.src_type.enumv() &&            \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&       \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {       \
            MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1, _midout_tags) {        \
                auto kern_padding = [bundle](const NCBKernParam& param,       \
                                             const NCBKernIndex& ncb_index) { \
                    copy_padding_kern<_src_ctype>(bundle, param, ncb_index);  \
                };                                                            \
                auto kern_default_packA = [bundle,                            \
                                           matmul_algo = m_matmul_algo,       \
                                           matmul_param](                     \
                                                  const NCBKernParam& param,  \
                                                  const NCBKernIndex&         \
                                                          ncb_index) {        \
                    Im2colKerns<Pack_Mode::DEFAULT>::packA_kern<_src_ctype>(  \
                            bundle, param, matmul_param, matmul_algo,         \
                            ncb_index);                                       \
                };                                                            \
                auto kern_only_packA =                                        \
                        [bundle, matmul_algo = m_matmul_algo, matmul_param](  \
                                const NCBKernParam& param,                    \
                                const NCBKernIndex& ncb_index) {              \
                            Im2colKerns<Pack_Mode::ONLY_PACKA>::packA_kern<   \
                                    _src_ctype>(bundle, param, matmul_param,  \
                                                matmul_algo, ncb_index);      \
                        };                                                    \
                COMPUTE_KERN(default, Pack_Mode::DEFAULT, _src_ctype,         \
                             _bias_ctype, _dst_ctype, _i_bias_type,           \
                             _i_dst_type, _postprocess_mode);                 \
                COMPUTE_KERN(nopack, Pack_Mode::NO_PACK, _src_ctype,          \
                             _bias_ctype, _dst_ctype, _i_bias_type,           \
                             _i_dst_type, _postprocess_mode);                 \
                COMPUTE_KERN(onlypackA, Pack_Mode::ONLY_PACKA, _src_ctype,    \
                             _bias_ctype, _dst_ctype, _i_bias_type,           \
                             _i_dst_type, _postprocess_mode);                 \
                RETURN_KERNS();                                               \
            }                                                                 \
            MIDOUT_END();                                                     \
            return {};                                                        \
        }                                                                     \
    } while (0);
        cb(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32,
           PostprocessMode::NO_PROCESS, 3);
        cb(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16,
           PostprocessMode::NO_PROCESS, 4);
        cb(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32,
           dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, 7);
        cb(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
           dt_int32, dt_int8, PostprocessMode::QUANTIZED, 8);
#undef COMPUTE_KERN
#undef RETURN_KERNS
#undef cb
        //! no cb() matched: the dtype combination is not supported
        megdnn_throw("unsupported data type on im2col matmul algo");
    }
    MIDOUT_END();
    return {};
}
  880. bool ConvBiasImpl::AlgoIm2col::usable(
  881. ConvBiasImpl* opr, const NCBKernSizeParam& param,
  882. AlgoSelectionStrategy /*algo_selection_strategy*/) const {
  883. MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) {
  884. //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is
  885. //! identity otherwise return false mean that 8x8x32 and 8x8x16 not support
  886. //! PostProcess
  887. if (param.src_type.enumv() == param.filter_type.enumv() &&
  888. ((param.src_type.enumv() == DTypeEnum::Int8 &&
  889. (param.dst_type.enumv() == DTypeEnum::Int16 ||
  890. param.dst_type.enumv() == DTypeEnum::Int32)) ||
  891. ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
  892. param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
  893. param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
  894. param.bias_mode != megdnn::BiasMode::NO_BIAS &&
  895. param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
  896. return false;
  897. }
  898. fallback::MatrixMulImpl::KernSizeParam matmul_param =
  899. get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
  900. bool matmulusable = m_matmul_algo->usable(matmul_param);
  901. return matmulusable &&
  902. (opr->param().format == param::ConvBias::Format::NCHW) &&
  903. (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
  904. (param.filter_meta.spatial[0] <= 7)) &&
  905. (param.filter_meta.dilation[0] ==
  906. param.filter_meta.dilation[1] &&
  907. param.filter_meta.dilation[0] == 1) &&
  908. param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
  909. }
  910. MIDOUT_END();
  911. return false;
  912. }
  913. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)