
winograd.h 39 kB

/**
* \file dnn/src/fallback/conv_bias/winograd/winograd.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cstddef>
#include "include/megdnn/basic_types.h"
#include "include/megdnn/dtype.h"
#include "include/megdnn/thin/small_vector.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_conv_bias_winograd_common)
namespace megdnn {
namespace winograd {
/**
* \brief Winograd convolution
*
* The algorithm is described in https://arxiv.org/abs/1509.09308.
*
* Format: DEFAULT
* filter: (OC, IC, FH, FW) -> (alpha, alpha, IC, OC)
* src: (N, C, H, W) -> (N, NR_TILES, alpha, alpha, TILE_SIZE, IC)
*
* We will perform gemm on:
* (TILE_SIZE, IC) x (IC, OC) -> (TILE_SIZE, OC)
*
* Format: MK4
* filter: (OC, IC, FH, FW) -> (alpha, alpha, OCB, ICB, IC_BLOCK_SIZE,
* OC_BLOCK_SIZE)
* src: (N, C, H, W) -> (N, NR_TILES, alpha, alpha, ICB, TILE_SIZE,
* IC_BLOCK_SIZE)
*
* We will perform gemm on:
* (OCB, ICB, IC_BLOCK_SIZE, OC_BLOCK_SIZE) x (ICB, TILE_SIZE, IC_BLOCK_SIZE)
* = (OCB, TILE_SIZE, OC_BLOCK_SIZE)
*/
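//! Worked example (illustrative): with OUTPUT_BLOCK_SIZE m = 2 and
//! KERNEL_SIZE r = 3, alpha = m + r - 1 = 4. In the DEFAULT format a
//! (OC, IC, 3, 3) filter is transformed into (4, 4, IC, OC), and each of the
//! 4 x 4 = 16 (i, j) positions runs an independent gemm of shape
//! (TILE_SIZE, IC) x (IC, OC) -> (TILE_SIZE, OC); the output transform then
//! maps every 4x4 tile back to a 2x2 block of the output feature map.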
//! The default oc size of one thread in multi-threads mode
constexpr static size_t UNIT_OC_SIZE_DEFAULT = 1024;
template <typename Strategy,
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
class ConvBias {
using output_compute_type = typename Strategy::output_compute_type;
using input_filter_compute_type =
typename Strategy::input_filter_compute_type;
using stype = typename Strategy::stype;
using dst_type = typename Strategy::dst_type;
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
using NCBKern = fallback::ConvBiasImpl::NCBKern;
static_assert(
format == param::MatrixMul::Format::DEFAULT ||
(format == param::MatrixMul::Format::MK4 &&
Strategy::IC_BLOCK_SIZE == 4 &&
Strategy::OC_BLOCK_SIZE == 4) ||
(format == param::MatrixMul::Format::MK8 &&
Strategy::IC_BLOCK_SIZE == 8 &&
Strategy::OC_BLOCK_SIZE == 8),
"format should be DEFAULT, MK4 or MK8; if MK4, IC_BLOCK_SIZE and "
"OC_BLOCK_SIZE should be 4; if MK8, IC_BLOCK_SIZE and "
"OC_BLOCK_SIZE should be 8");
Strategy m_strategy;
size_t m_unit_tile_size;
//! m_unit_oc_size must be a multiple of Strategy::OC_BLOCK_SIZE
size_t m_unit_oc_size;
WorkspaceBundle get_wbundle(
const NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) const {
size_t OC = param.filter_meta.ocpg;
size_t IC = param.filter_meta.icpg;
size_t GROUP = param.filter_meta.group;
size_t nr_threads = param.nr_threads;
size_t filter_transform_buf_size = 0;
//! filter : (alpha, alpha, IC, OC) or (OCB, ICB, IC_BLOCK_SIZE,
//! OC_BLOCK_SIZE)
if (param.preprocessed_filter == nullptr &&
param.filter_meta.format !=
param::ConvBias::Format::NCHW_WINOGRAD &&
param.filter_meta.format !=
param::ConvBias::Format::NCHW88_WINOGRAD &&
param.filter_meta.format !=
param::ConvBias::Format::NCHW44_WINOGRAD) {
filter_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA * OC *
IC * sizeof(input_filter_compute_type);
}
size_t winograd_comput_size =
get_wbundle_compute(param, matmul_algo).total_size_in_bytes() *
nr_threads;
if (param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44) {
return WorkspaceBundle(
nullptr,
{winograd_comput_size, filter_transform_buf_size * GROUP});
} else {
megdnn_assert(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW88_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW44_WINOGRAD);
return WorkspaceBundle(nullptr, {winograd_comput_size});
}
}
WorkspaceBundle get_wbundle_compute(
const NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) const {
size_t OC = param.filter_meta.ocpg;
size_t IC = param.filter_meta.icpg;
size_t oc_size = std::min(OC, m_unit_oc_size);
//! input : (alpha, alpha, unit_tile_size, IC) or (alpha, alpha,
//! ICB, unit_tile_size, IC_BLOCK_SIZE)
size_t input_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA *
IC * m_unit_tile_size *
sizeof(input_filter_compute_type);
//! output : (alpha, alpha, unit_tile_size, OC) or
//! (alpha, alpha, OCB, unit_tile_size, OC_BLOCK_SIZE)
size_t output_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA *
oc_size * m_unit_tile_size *
sizeof(output_compute_type);
//! used for internal temporary storage
size_t transform_mid_buf_size =
2 * Strategy::ALPHA * Strategy::ALPHA *
sizeof(output_compute_type) *
std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE);
size_t matmul_workspace_size = matmul_algo->get_workspace(
get_matmul_kern_param(param, m_unit_oc_size));
//! the compute workspaces are kept independent and separated as far
//! as possible to avoid false cache-line sharing
return WorkspaceBundle(
nullptr, {input_transform_buf_size, output_transform_buf_size,
transform_mid_buf_size, matmul_workspace_size});
}
WorkspaceBundle get_preprocess_wbundle(
const NCBKernSizeParam& param) const {
//! used for internal temporary storage
size_t transform_mid_buf_size =
2 * Strategy::ALPHA * Strategy::ALPHA *
sizeof(output_compute_type) *
std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE);
size_t nr_threads = param.nr_threads;
SmallVector<size_t> space_vec(nr_threads, transform_mid_buf_size);
return WorkspaceBundle{nullptr, space_vec};
}
public:
//! Compute m_unit_oc_size from nr_threads and the output feature-map
//! size. With a single thread, m_unit_oc_size is set to
//! UNIT_OC_SIZE_DEFAULT heuristically; with multiple threads, it is
//! derived from nr_threads and the output feature-map size.
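//! Example of the heuristic below (illustrative numbers): with OH = OW = 14
//! and OUTPUT_BLOCK_SIZE = 2, nr_units = 7 * 7 = 49; with unit_tile_size =
//! 16, nr_parallism_unit = div_ceil(49, 16) = 4. If nr_threads = 8 > 4 and
//! OC = 64, then m_unit_oc_size = round_up(div_ceil(64, 8), 4) = 8, so each
//! thread gets its own oc block; otherwise UNIT_OC_SIZE_DEFAULT is used.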
ConvBias(const Strategy& strategy, size_t unit_tile_size,
const NCBKernSizeParam& param)
: m_strategy{strategy}, m_unit_tile_size{unit_tile_size} {
size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg;
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (nr_threads > 1) {
size_t units_h = div_ceil<size_t>(OH, Strategy::OUTPUT_BLOCK_SIZE);
size_t units_w = div_ceil<size_t>(OW, Strategy::OUTPUT_BLOCK_SIZE);
size_t nr_units = units_h * units_w;
size_t nr_parallism_unit =
div_ceil<size_t>(nr_units, unit_tile_size);
if (nr_parallism_unit < nr_threads) {
m_unit_oc_size = div_ceil<size_t>(OC, nr_threads);
if (format == param::MatrixMul::Format::MK8) {
m_unit_oc_size = round_up<size_t>(m_unit_oc_size, 8);
} else {
m_unit_oc_size = round_up<size_t>(m_unit_oc_size, 4);
}
} else {
m_unit_oc_size = UNIT_OC_SIZE_DEFAULT;
}
} else {
m_unit_oc_size = UNIT_OC_SIZE_DEFAULT;
}
}
ConvBias(const Strategy& strategy, size_t unit_tile_size)
: m_strategy{strategy}, m_unit_tile_size{unit_tile_size} {
m_unit_oc_size = UNIT_OC_SIZE_DEFAULT;
}
size_t get_workspace_size(
const NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) const {
return get_wbundle(param, matmul_algo).total_size_in_bytes();
}
size_t get_preprocess_workspace_size(
const NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase*) const {
return get_preprocess_wbundle(param).total_size_in_bytes();
}
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const NCBKernSizeParam& param, fallback::MatrixMulImpl::AlgoBase*) {
size_t OC = param.filter_meta.ocpg;
size_t IC = param.filter_meta.icpg;
size_t GROUP = param.filter_meta.group;
SmallVector<TensorLayout> preprocessed_layouts;
DType dtype = m_strategy.filter_dtype;
if (dtype.category() == DTypeCategory::QUANTIZED) {
if (format == param::MatrixMul::Format::MK4) {
dtype = dtype::Float32();
} else if (format == param::MatrixMul::Format::MK8) {
dtype = dtype::Int16();
}
}
if (format == param::MatrixMul::Format::DEFAULT) {
preprocessed_layouts.push_back(
{{GROUP, Strategy::ALPHA, Strategy::ALPHA, OC, IC}, dtype});
} else if (format == param::MatrixMul::Format::MK4) {
preprocessed_layouts.push_back(
{{GROUP, Strategy::ALPHA, Strategy::ALPHA, OC / 4, IC / 4,
4, 4},
dtype});
} else {
megdnn_assert(format == param::MatrixMul::Format::MK8);
preprocessed_layouts.push_back(
{{GROUP, Strategy::ALPHA, Strategy::ALPHA, OC / 8, IC / 8,
8, 8},
dtype});
}
return preprocessed_layouts;
}
//! Used by the winograd_filter_preprocess opr
void filter_process(const stype* filter_ptr,
input_filter_compute_type* filter_transform_buf,
void* transform_mid_buf, size_t OC, size_t IC) {
m_strategy.filter(
filter_ptr, filter_transform_buf,
static_cast<input_filter_compute_type*>(transform_mid_buf), OC,
IC, 0, OC);
}
static void filter_process(Strategy strategy,
const WorkspaceBundle& bundle_top,
const WorkspaceBundle& bundle_compute,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t compute_workspace_size_per_thread =
bundle_compute.total_size_in_bytes();
size_t thread_id = ncb_index.thread_id;
size_t oc_id = ncb_index.ndrange_id[2];
size_t group_id = ncb_index.ndrange_id[0];
size_t OC = kern_param.filter_meta.ocpg;
size_t IC = kern_param.filter_meta.icpg;
size_t filter_group_size = Strategy::ALPHA * Strategy::ALPHA * OC * IC *
sizeof(input_filter_compute_type);
//! Filter transform dst ptr
input_filter_compute_type* filter_transform_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_top.get(1)) +
group_id * filter_group_size);
//! Filter transform src ptr
input_filter_compute_type* transform_mid_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(2)) +
compute_workspace_size_per_thread * thread_id);
const stype* filter_ptr = kern_param.filter<stype>(group_id);
size_t oc_start = oc_id, oc_end = oc_id + 1;
if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) {
oc_start = 8 * oc_id;
oc_end = oc_start + 8;
} else if (kern_param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
oc_start = 4 * oc_id;
oc_end = oc_start + 4;
}
strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC,
IC, oc_start, oc_end);
}
static void filter_preprocess(Strategy strategy,
const WorkspaceBundle& bundle,
const TensorND& preprocessed_tensor,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t thread_id = ncb_index.thread_id;
size_t oc_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
size_t OC = kern_param.filter_meta.ocpg;
size_t IC = kern_param.filter_meta.icpg;
size_t filter_group_size = Strategy::ALPHA * Strategy::ALPHA * OC * IC *
sizeof(input_filter_compute_type);
//! Filter transform dst ptr
input_filter_compute_type* filter_transform_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(
preprocessed_tensor.raw_ptr) +
group_id * filter_group_size);
//! Filter transform src ptr
input_filter_compute_type* transform_mid_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle.get(thread_id)));
const stype* filter_ptr = kern_param.filter<stype>(group_id);
size_t oc_start, oc_end;
if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) {
oc_start = 8 * oc_id;
oc_end = oc_start + 8;
} else if (kern_param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
oc_start = 4 * oc_id;
oc_end = oc_start + 4;
} else {
oc_start = oc_id;
oc_end = oc_id + 1;
}
strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC,
IC, oc_start, oc_end);
}
static void winograd_compute(
Strategy strategy, const WorkspaceBundle& bundle_top,
const WorkspaceBundle& bundle_compute,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
fallback::MatrixMulImpl::KernParam matmul_param,
size_t unit_tile_size, size_t unit_oc_size,
const NCBKernParam& ncb_param, const NCBKernIndex& ncb_index) {
size_t OC = ncb_param.filter_meta.ocpg;
size_t IC = ncb_param.filter_meta.icpg;
size_t IH = ncb_param.isz[0];
size_t IW = ncb_param.isz[1];
size_t OH = ncb_param.osz[0];
size_t OW = ncb_param.osz[1];
size_t PH = ncb_param.filter_meta.padding[0];
size_t PW = ncb_param.filter_meta.padding[1];
size_t filter_group_size = Strategy::ALPHA * Strategy::ALPHA * OC * IC *
sizeof(input_filter_compute_type);
size_t compute_workspace_size_per_thread =
bundle_compute.total_size_in_bytes();
size_t units_h = div_ceil<size_t>(OH, Strategy::OUTPUT_BLOCK_SIZE);
size_t units_w = div_ceil<size_t>(OW, Strategy::OUTPUT_BLOCK_SIZE);
size_t nr_units = units_h * units_w;
size_t oc_block_id = ncb_index.ndrange_id[3];
size_t tile_id = ncb_index.ndrange_id[2];
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
size_t thread_id = ncb_index.thread_id;
const stype* src_ptr = ncb_param.src<stype>(batch_id, group_id);
dst_type* dst_ptr = ncb_param.dst<dst_type>(batch_id, group_id);
const output_compute_type* bias_ptr =
static_cast<const output_compute_type*>(
ncb_param.bias<output_compute_type>(batch_id,
group_id));
input_filter_compute_type* input_transform_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(0)) +
compute_workspace_size_per_thread * thread_id);
output_compute_type* output_transform_buf =
reinterpret_cast<output_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(1)) +
compute_workspace_size_per_thread * thread_id);
input_filter_compute_type* transform_mid_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(2)) +
compute_workspace_size_per_thread * thread_id);
//! NCHW88_WINOGRAD and NCHW_WINOGRAD use the same offset
const input_filter_compute_type* filter_transform_buf = nullptr;
if (nullptr != ncb_param.preprocessed_filter) {
auto preprocess_raw_ptr =
ncb_param.preprocessed_filter->tensors[0].raw_ptr;
filter_transform_buf = reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(preprocess_raw_ptr) +
group_id * filter_group_size);
} else {
filter_transform_buf =
static_cast<const input_filter_compute_type*>(
ncb_param.filter<input_filter_compute_type>(
group_id));
if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW ||
ncb_param.filter_meta.format ==
param::ConvBias::Format::NCHW88 ||
ncb_param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
filter_transform_buf =
reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_top.get(1)) +
group_id * filter_group_size);
}
}
//! prepare matmul param
matmul_param.workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(3)) +
compute_workspace_size_per_thread * thread_id);
matmul_param.workspace_size = bundle_compute.get_size(3);
fallback::MatrixMulImpl::kern_t matmul_kern =
matmul_algo->get_kern(matmul_param);
size_t unit_start_idx = tile_id * unit_tile_size;
size_t nr_tiles_in_unit =
std::min(nr_units - unit_start_idx, unit_tile_size);
size_t oc_start_idx = oc_block_id * unit_oc_size;
size_t nr_oc_in_unit = std::min(OC - oc_start_idx, unit_oc_size);
megdnn_assert(nr_oc_in_unit % Strategy::OC_BLOCK_SIZE == 0,
"the remaining OC of winograd is not a multiple of OC_BLOCK_SIZE");
if (format == param::MatrixMul::Format::MK4 ||
format == param::MatrixMul::Format::MK8) {
#if !MEGDNN_X86
nr_tiles_in_unit = round_up<size_t>(nr_tiles_in_unit, 4);
#endif
megdnn_assert(nr_tiles_in_unit <= unit_tile_size,
"nr_tiles_in_unit: %zu TILE_SIZE:%zu",
nr_tiles_in_unit, unit_tile_size);
}
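//! Per-tile Winograd pipeline (see the paper referenced above): the input
//! transform below computes V = B^T d B for each tile, the loop over
//! (i, j) in alpha x alpha multiplies the transformed filter and input as
//! one gemm per position, and the output transform at the end computes
//! Y = A^T M A to recover the spatial output block.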
//! BTdB
strategy.input(src_ptr, input_transform_buf, transform_mid_buf,
IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit);
rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) {
if (format == param::MatrixMul::Format::DEFAULT) {
matmul_param.A_ptr =
input_transform_buf +
(i * Strategy::ALPHA + j) * nr_tiles_in_unit * IC;
matmul_param.B_ptr = filter_transform_buf +
(i * Strategy::ALPHA + j) * OC * IC +
oc_start_idx;
matmul_param.C_ptr = output_transform_buf +
(i * Strategy::ALPHA + j) *
nr_tiles_in_unit * nr_oc_in_unit;
matmul_param.M = nr_tiles_in_unit;
matmul_param.N = nr_oc_in_unit;
matmul_param.LDB = OC;
matmul_param.LDC = nr_oc_in_unit;
} else {
matmul_param.A_ptr = filter_transform_buf +
(i * Strategy::ALPHA + j) * OC * IC +
oc_start_idx * IC;
matmul_param.B_ptr =
input_transform_buf +
(i * Strategy::ALPHA + j) * nr_tiles_in_unit * IC;
matmul_param.C_ptr = output_transform_buf +
(i * Strategy::ALPHA + j) *
nr_tiles_in_unit * nr_oc_in_unit;
matmul_param.N = nr_tiles_in_unit;
matmul_param.M = nr_oc_in_unit;
matmul_param.LDB = matmul_param.N * Strategy::IC_BLOCK_SIZE;
matmul_param.LDC = matmul_param.N * Strategy::IC_BLOCK_SIZE;
}
matmul_kern(matmul_param);
}
//! Y = ATmA
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit;
strategy.output(
output_transform_buf, bias_ptr, dst_ptr,
reinterpret_cast<output_compute_type*>(transform_mid_buf),
ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW,
oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit);
};
SmallVector<NCBKern> get_preprocess_kerns(
const NCBKernSizeParam& param, fallback::MatrixMulImpl::AlgoBase*) {
megdnn_assert(
param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44);
megdnn_assert(param.preprocessed_filter &&
param.preprocessed_filter->tensors.size() > 0);
size_t OC = param.filter_meta.ocpg;
size_t GROUP = param.filter_meta.group;
const TensorND& preprocessed_dst =
param.preprocessed_filter->tensors[0];
WorkspaceBundle bundle = get_preprocess_wbundle(param);
Strategy strategy = m_strategy;
SmallVector<NCBKern> kerns;
auto filter_process_kern =
[strategy, bundle, &preprocessed_dst, this](
const NCBKernParam& ncb_param,
const NCBKernIndex& ncb_index) mutable {
MEGDNN_MARK_USED_VAR(this);
MIDOUT_BEGIN(megdnn_fallback_conv_bias_winograd_common,
midout_iv("filter_preprocess"_hash)) {
bundle.set(ncb_param.workspace_ptr);
filter_preprocess(strategy, bundle, preprocessed_dst,
ncb_param, ncb_index);
}
MIDOUT_END();
};
size_t oc_parallelism = OC;
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
megdnn_assert(OC % 8 == 0);
oc_parallelism = OC / 8;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
megdnn_assert(OC % 4 == 0);
oc_parallelism = OC / 4;
}
kerns.push_back({filter_process_kern, {GROUP, oc_parallelism}});
return kerns;
}
SmallVector<NCBKern> get_kerns(
const NCBKernSizeParam& param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t N = param.n;
size_t OC = param.filter_meta.ocpg;
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t GROUP = param.filter_meta.group;
WorkspaceBundle bundle_top = get_wbundle(param, matmul_algo);
WorkspaceBundle bundle_compute =
get_wbundle_compute(param, matmul_algo);
fallback::MatrixMulImpl::KernParam matmul_param;
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
get_matmul_kern_param(param, m_unit_oc_size);
size_t unit_tile_size = m_unit_tile_size;
size_t unit_oc_size = m_unit_oc_size;
size_t units_h = div_ceil<size_t>(OH, Strategy::OUTPUT_BLOCK_SIZE);
size_t units_w = div_ceil<size_t>(OW, Strategy::OUTPUT_BLOCK_SIZE);
size_t nr_units = units_h * units_w;
size_t nr_hw_tiles = div_ceil<size_t>(nr_units, m_unit_tile_size);
size_t nr_oc_tiles = div_ceil<size_t>(OC, m_unit_oc_size);
//! The filter should be processed ahead of time
megdnn_assert(
param.filter_meta.stride[0] == 1 &&
param.filter_meta.stride[1] == 1 &&
(param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW88_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW44_WINOGRAD));
SmallVector<NCBKern> kerns;
if (param.preprocessed_filter == nullptr &&
(param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44)) {
auto filter_process_kern =
[strategy = m_strategy, bundle_top, bundle_compute, this](
const NCBKernParam& ncb_param,
const NCBKernIndex& ncb_index) mutable {
MEGDNN_MARK_USED_VAR(this);
MIDOUT_BEGIN(megdnn_fallback_conv_bias_winograd_common,
midout_iv("filter_process"_hash)) {
bundle_top.set(ncb_param.workspace_ptr);
bundle_compute.set(bundle_top.get(0));
filter_process(strategy, bundle_top, bundle_compute,
ncb_param, std::move(ncb_index));
}
MIDOUT_END();
};
size_t oc_parallelism = OC;
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
megdnn_assert(OC % 8 == 0);
oc_parallelism = OC / 8;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
megdnn_assert(OC % 4 == 0);
oc_parallelism = OC / 4;
}
kerns.push_back({filter_process_kern, {GROUP, 1, oc_parallelism}});
}
auto winograd_compute_kern =
[strategy = m_strategy, bundle_top, bundle_compute, matmul_algo,
matmul_param, unit_tile_size, unit_oc_size,
this](const NCBKernParam& ncb_param,
const NCBKernIndex& ncb_index) mutable {
MEGDNN_MARK_USED_VAR(this);
MIDOUT_BEGIN(megdnn_fallback_conv_bias_winograd_common,
midout_iv("winograd_compute"_hash)) {
bundle_top.set(ncb_param.workspace_ptr);
bundle_compute.set(bundle_top.get(0));
winograd_compute(strategy, bundle_top, bundle_compute,
matmul_algo, matmul_param,
unit_tile_size, unit_oc_size,
ncb_param, std::move(ncb_index));
}
MIDOUT_END();
};
kerns.push_back(
{winograd_compute_kern, {GROUP, N, nr_hw_tiles, nr_oc_tiles}});
return kerns;
}
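//! Matmul shapes used by the winograd kernels: in DEFAULT format
//! M = unit_tile_size, N = nr_oc_in_unit, K = IC; in MK4/MK8 the operands
//! swap roles, M = nr_oc_in_unit and N = unit_tile_size, with leading
//! dimensions counted in elements of the packed blocks. Illustrative
//! numbers for MK4 with IC = 16 and unit_tile_size = 24:
//! LDA = 16 / 4 * 4 * 4 = 64, LDB = LDC = 24 * 4 = 96.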
fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const {
size_t M = 0;
size_t N = 0;
size_t K = 0;
size_t LDA = 0, LDB = 0, LDC = 0;
if (nr_oc_in_unit == 0) {
nr_oc_in_unit = param.filter_meta.ocpg;
}
if (format == param::MatrixMul::Format::DEFAULT) {
M = m_unit_tile_size;
N = nr_oc_in_unit;
K = param.filter_meta.icpg;
LDA = K;
LDB = N;
LDC = N;
} else {
M = nr_oc_in_unit;
N = m_unit_tile_size;
K = param.filter_meta.icpg;
megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu",
K);
LDA = K / Strategy::IC_BLOCK_SIZE * Strategy::OC_BLOCK_SIZE *
Strategy::IC_BLOCK_SIZE;
LDB = N * Strategy::IC_BLOCK_SIZE;
LDC = N * Strategy::IC_BLOCK_SIZE;
}
return {DType::from_enum(DTypeTrait<input_filter_compute_type>::enumv),
DType::from_enum(DTypeTrait<input_filter_compute_type>::enumv),
DType::from_enum(DTypeTrait<output_compute_type>::enumv),
M,
N,
K,
LDA,
LDB,
LDC,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
format};
}
};
} // namespace winograd
} // namespace megdnn
#define MEGDNN_REG_WINOGRAD_STRATEGY( \
_stype, _dtype, _input_filter_ctype, _ctype, _output_block_size, \
_kernel_size, _ic_block_size, _oc_block_size, _strategy_cls_name) \
class _strategy_cls_name { \
public: \
using stype = _stype; \
using dst_type = _dtype; \
using output_compute_type = _ctype; \
using input_filter_compute_type = _input_filter_ctype; \
/** \
* kernel size of convolution, same as \c r \
* output block size, same as \c m \
*/ \
constexpr static size_t KERNEL_SIZE = _kernel_size; \
constexpr static size_t OUTPUT_BLOCK_SIZE = _output_block_size; \
constexpr static size_t IC_BLOCK_SIZE = _ic_block_size; \
constexpr static size_t OC_BLOCK_SIZE = _oc_block_size; \
constexpr static size_t ALPHA = KERNEL_SIZE + OUTPUT_BLOCK_SIZE - 1; \
/** \
* process \c UNIT_TILE_SIZE small matrix muls at once; the total tile \
* count is \
* N * DIV_UP(OH, OUTPUT_BLOCK_SIZE) * DIV_UP(OW, OUTPUT_BLOCK_SIZE) \
*/ \
const DType src_dtype; \
const DType filter_dtype; \
const DType dst_dtype; \
_strategy_cls_name(DType src_dtype, DType filter_dtype, \
DType dst_dtype); \
void filter(const stype* filter, \
input_filter_compute_type* filter_transform_buf, \
input_filter_compute_type* transform_mid_buf, size_t OC, \
size_t IC, size_t oc_start, size_t oc_end); \
void input(const stype* input, \
input_filter_compute_type* input_transform_buf, \
input_filter_compute_type* transform_mid_buf, \
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \
size_t unit_start_idx, size_t nr_tiles_in_unit); \
void output(const output_compute_type* output_transform_buf, \
const output_compute_type* bias, dst_type* output, \
output_compute_type* transform_mid_buf, BiasMode bmode, \
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \
};
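//! A hypothetical instantiation (for illustration only), declaring an fp32
//! F(2x2, 3x3) strategy for the DEFAULT matmul format:
//!   MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 1, 1,
//!                                winograd_2x3_1x1_f)
//! This expands to a class winograd_2x3_1x1_f with KERNEL_SIZE = 3,
//! OUTPUT_BLOCK_SIZE = 2, ALPHA = 4 and IC_BLOCK_SIZE = OC_BLOCK_SIZE = 1;
//! its filter/input/output transforms are then defined in a source file
//! together with MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_f).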
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \
constexpr size_t _strategy_cls_name::KERNEL_SIZE; \
constexpr size_t _strategy_cls_name::OUTPUT_BLOCK_SIZE; \
constexpr size_t _strategy_cls_name::ALPHA; \
constexpr size_t _strategy_cls_name::IC_BLOCK_SIZE; \
constexpr size_t _strategy_cls_name::OC_BLOCK_SIZE; \
_strategy_cls_name::_strategy_cls_name( \
DType src_dtype, DType filter_dtype, DType dst_dtype) \
: src_dtype(src_dtype), \
filter_dtype(filter_dtype), \
dst_dtype(dst_dtype) {}
#define MEGDNN_WINOGRADS_ALGO_FUN_DEFINE(_class, _fun, _strategy, \
_midout_flag, _matmul_format) \
MEGDNN_MARK_USED_VAR(param); \
MIDOUT_BEGIN(_midout_flag, midout_iv(#_class #_fun##_hash)) { \
_strategy strategy(param.src_type, param.filter_type, param.dst_type); \
return megdnn::winograd::ConvBias<_strategy, _matmul_format>( \
strategy, m_tile_size, param) \
._fun(param, m_matmul_algo); \
} \
MIDOUT_END();
#define MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(_class, _strategy, _midout_flag, \
_matmul_format) \
size_t ConvBiasImpl::_class::get_workspace(const NCBKernSizeParam& param) \
const { \
MEGDNN_WINOGRADS_ALGO_FUN_DEFINE(_class, get_workspace_size, \
_strategy, _midout_flag, \
_matmul_format); \
return 0; \
} \
size_t ConvBiasImpl::_class::get_preprocess_workspace( \
const NCBKernSizeParam& param) const { \
MEGDNN_WINOGRADS_ALGO_FUN_DEFINE( \
_class, get_preprocess_workspace_size, _strategy, \
_midout_flag, _matmul_format); \
return 0; \
} \
SmallVector<TensorLayout> \
ConvBiasImpl::_class::deduce_preprocessed_filter_layout( \
const NCBKernSizeParam& param) const { \
MEGDNN_WINOGRADS_ALGO_FUN_DEFINE( \
_class, deduce_preprocessed_filter_layout, _strategy, \
_midout_flag, _matmul_format); \
return {}; \
} \
SmallVector<ConvBiasImpl::NCBKern> \
ConvBiasImpl::_class::dispatch_preprocess_kerns( \
const NCBKernSizeParam& param) const { \
MEGDNN_WINOGRADS_ALGO_FUN_DEFINE(_class, get_preprocess_kerns, \
_strategy, _midout_flag, \
_matmul_format); \
return {}; \
} \
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::_class::dispatch_kerns( \
const NCBKernSizeParam& param) const { \
MEGDNN_WINOGRADS_ALGO_FUN_DEFINE(_class, get_kerns, _strategy, \
_midout_flag, _matmul_format); \
return {}; \
}
// vim: syntax=cpp.doxygen
