You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 70 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690
  1. /**
  2. * \file dnn/include/megdnn/oprs/nn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
/**
 * \brief base class for separable convolution, where the kernel is given as
 * two separate filter tensors (filter_x and filter_y)
 */
class SeparableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
    DEF_OPR_PARAM(SeparableConv);

public:
    using Mode = Param::Mode;

protected:
    //! deduce the layout of \p dst from \p src and the two filters
    void deduce_layout_fwd(const TensorLayout& src,
                           const TensorLayout& filter_x,
                           const TensorLayout& filter_y, TensorLayout& dst);
    //! check that the four layouts form a valid forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
                          const TensorLayout& filter_y,
                          const TensorLayout& dst);
};
//! forward (inference) pass of separable convolution
class SeparableConvForward : public SeparableConvBase {
    DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter_x first filter tensor
     * \param[in] filter_y second filter tensor
     * \param[out] dst output tensor; layout given by deduce_layout()
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
                      _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce the layout of \p dst from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
                       const TensorLayout& filter_y, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter_x,
                                          const TensorLayout& filter_y,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and that \p workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
                    const TensorLayout& filter_y, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using SeparableConv = SeparableConvForward;
namespace detail {
/**
 * \brief storage for weights that have been preprocessed by a convolution
 * algorithm (see ConvolutionForward::exec_preprocess)
 */
struct PreprocessedFilter {
    //! user data; its lifetime should be bound to MegDNN Convolution
    //! operator
    void* algorithm_id;
    //! the preprocessed weight tensors themselves
    TensorNDArray tensors;
};
}  // namespace detail
  54. /**
  55. * \brief base class for convolution operation
  56. *
  57. * This operator is supposed to perform convolution on arbitrary input
  58. * dimensions. The input/output format is N, C, dims..., and kernel format can
  59. * take two forms:
  60. * 1. OC, IC, dims..., for conventional dense convolution
  61. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  62. *
  63. * Currently, only 2D images are supported.
  64. */
  65. template <typename Parameter>
  66. class ConvolutionBase : public OperatorBase {
  67. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  68. using Param = Parameter;
  69. public:
  70. Param& param() { return m_param; }
  71. const Param& param() const { return m_param; }
  72. protected:
  73. Param m_param;
  74. public:
  75. static constexpr size_t MAX_SPATIAL_DIM = 2;
  76. using Mode = typename Param::Mode;
  77. struct CanonizedFilterMeta {
  78. DType dtype;
  79. typename Param::Format format;
  80. uint32_t
  81. //! whether filter should be flipped (i.e. is CONVOLUTION)
  82. should_flip,
  83. group, //!< number of groups
  84. icpg, //!< input channels per group
  85. ocpg, //!< output channels per group
  86. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  87. //! spatial dim
  88. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  89. //! spatial dim with dilation applied
  90. dilated_spatial[MAX_SPATIAL_DIM];
  91. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  92. template <typename T>
  93. void copy_from(const T& b) {
  94. dtype = b.dtype;
  95. format = b.format;
  96. should_flip = b.should_flip;
  97. group = b.group;
  98. icpg = b.icpg;
  99. ocpg = b.ocpg;
  100. spatial_ndim = b.spatial_ndim;
  101. memcpy(stride, b.stride, sizeof(stride));
  102. memcpy(padding, b.padding, sizeof(padding));
  103. memcpy(spatial, b.spatial, sizeof(spatial));
  104. memcpy(dilation, b.dilation, sizeof(dilation));
  105. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  106. }
  107. bool operator==(const CanonizedFilterMeta& b) const {
  108. bool flag = true;
  109. flag = flag && (format == b.format);
  110. flag = flag && (dtype == b.dtype);
  111. flag = flag && (should_flip == b.should_flip);
  112. flag = flag && (group == b.group);
  113. flag = flag && (icpg == b.icpg);
  114. flag = flag && (ocpg == b.ocpg);
  115. flag = flag && (spatial_ndim == b.spatial_ndim);
  116. if (flag) {
  117. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  118. flag = flag && (stride[i] == b.stride[i]);
  119. flag = flag && (padding[i] == b.padding[i]);
  120. flag = flag && (spatial[i] == b.spatial[i]);
  121. flag = flag && (dilation[i] == b.dilation[i]);
  122. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  123. }
  124. }
  125. return flag;
  126. }
  127. };
  128. using PreprocessedFilter = detail::PreprocessedFilter;
  129. protected:
  130. // Check or deduce output DType
  131. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  132. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  133. const TensorLayout& filter,
  134. TensorLayout& dst) const;
  135. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  136. const TensorLayout& filter,
  137. const TensorLayout& dst) const;
  138. CanonizedFilterMeta make_canonized_filter_meta(
  139. size_t src_ndim, const TensorLayout& filter) const;
  140. };
//! propagates a mask tensor through an operator window (see MaskConvForward)
class MaskPropagate : public OperatorBase {
    DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
    DEF_OPR_PARAM(MaskPropagate);

public:
    /**
     * \param[in] src input mask tensor
     * \param[out] dst output mask tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;
    //! deduce the layout of \p dst from \p src
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
};
  151. /**
  152. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  153. */
  154. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  155. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  156. public:
  157. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  158. _megdnn_tensor_in mask, _megdnn_tensor_out dst,
  159. _megdnn_workspace worksapce) = 0;
  160. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  161. const TensorLayout& filter,
  162. const TensorLayout& mask,
  163. const TensorLayout& dst) = 0;
  164. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  165. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  166. const TensorLayout& mask, TensorLayout& dst);
  167. protected:
  168. CanonizedFilterMeta check_exec(const TensorLayout& src,
  169. const TensorLayout& filter,
  170. const TensorLayout& mask,
  171. const TensorLayout& dst,
  172. size_t workspace_in_bytes);
  173. };
  174. using MaskConvolution = MaskConvForward;
/**
 * \brief ConvolutionForward operator.
 */
class ConvolutionForward : public ConvolutionBase<param::Convolution>,
                           public detail::MultiAlgoOpr<ConvolutionForward, 3> {
    DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] preprocessed_filter if the weight is not preprocessed it
     * will be nullptr; otherwise the preprocessed weights are stored in the
     * tensors of preprocessed_filter.
     * \param[in] workspace if the weight is not preprocessed
     * (preprocessed_filter == nullptr), the size of the workspace satisfies
     * the situation where the weights are not preprocessed; otherwise it
     * satisfies the situation where the weights are preprocessed
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst,
                      const PreprocessedFilter* preprocessed_filter,
                      _megdnn_workspace workspace) = 0;
    /**
     * \brief execute weight preprocessing: read weights from filter and
     * write to preprocessed_filter after preprocessing.
     *
     * \param[in] workspace the needed tmp workspace when exec_preprocess runs
     */
    virtual void exec_preprocess(const TensorLayout& src_layout,
                                 _megdnn_tensor_in filter,
                                 const TensorLayout& dst_layout,
                                 PreprocessedFilter* preprocessed_filter,
                                 _megdnn_workspace workspace) = 0;
    //! deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType& dst);
    //! deduce the output layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    /**
     * \brief query the workspace needed when executing the opr; if the
     * weights are preprocessed, preprocessed_filter will not be nullptr,
     * else it will be nullptr; the workspace size may differ depending on
     * whether the weights are preprocessed
     *
     * \return the size of workspace needed when executing
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief deduce the preprocessed filter layouts according to the src,
     * filter and dst layout; the result may contain multiple layouts when
     * there is more than one weight tensor
     *
     * \return SmallVector<TensorLayout> the layouts of weight
     * preprocessing; empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    /**
     * \brief query the workspace needed when preprocessing the weights;
     * according to the returned size, a _megdnn_workspace will be created
     * and passed through exec_preprocess
     *
     * \return the size of workspace needed when preprocessing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using Convolution = ConvolutionForward;
/**
 * \brief ConvolutionBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data.
 */
class ConvolutionBackwardData
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce the gradient dtype from filter and diff dtypes
    void deduce_dtype(DType filter, DType diff, DType& grad);
    //! deduce the gradient layout from filter and diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class ConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oc, ic, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBias operator: convolution + bias (+ optional residual
 * addition of z), fused in one operator
 */
class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
                        public detail::MultiAlgoOpr<ConvBiasForward, 5> {
    DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
     * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
     * 4 * ic)
     * \param[in] bias (1, oc, 1, 1)
     * \param[in] z same as dst
     * \param[in] preprocessed_filter if the weight is not preprocessed it
     * will be nullptr; otherwise the preprocessed weights are stored in the
     * tensors of preprocessed_filter.
     * \param[in] workspace if the weight is not preprocessed
     * (preprocessed_filter == nullptr), the size of the workspace satisfies
     * the situation where the weights are not preprocessed; otherwise it
     * satisfies the situation where the weights are preprocessed
     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
     *
     * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
     * alphaw, oc, ic)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst,
                      const PreprocessedFilter* preprocessed_filter,
                      _megdnn_workspace workspace) = 0;
    /**
     * \brief execute weight preprocessing: read weights from filter and bias,
     * write to preprocessed_filter after preprocessing.
     *
     * \param[in] workspace the needed tmp workspace while exec_preprocess is
     * running; its size is given by get_preprocess_workspace_in_bytes
     */
    virtual void exec_preprocess(const TensorLayout& src_layout,
                                 _megdnn_tensor_in filter,
                                 _megdnn_tensor_in bias,
                                 const TensorLayout& z_layout,
                                 const TensorLayout& dst_layout,
                                 PreprocessedFilter* preprocessed_filter,
                                 _megdnn_workspace workspace) = 0;
    //! deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! deduce the output layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    /**
     * \brief query the workspace needed when executing the opr; if the
     * weights are preprocessed, preprocessed_filter will not be nullptr,
     * else it will be nullptr; the workspace size may differ depending on
     * whether the weights are preprocessed
     *
     * \return the size of workspace needed when executing
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief query the workspace needed when pre-processing the weights;
     * according to the returned size, a _megdnn_workspace will be created
     * and passed through exec_preprocess
     *
     * \return the size of workspace needed when pre-processing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    /**
     * \brief deduce the pre-processed filter layouts according to the src,
     * filter and dst layout, which may contain multiple layouts when there
     * is more than one weight tensor
     *
     * \return SmallVector<TensorLayout> the layouts of weight
     * preprocessing; empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    //! interpretation of the bias tensor
    enum class BiasMode : uint32_t {
        NO_BIAS = 0,             //!< no bias
        BROADCAST_CHANNEL_BIAS,  //!< broadcast channel bias, [1, c, 1, 1]
        BIAS                     //!< [N, C, H, W]
    };
    //! param for winograd algos.
    struct WinogradParam {
        uint32_t channel_block_size;
        uint32_t output_block_size;
        uint32_t tile_size;
        bool operator==(const WinogradParam& rhs) const {
            return channel_block_size == rhs.channel_block_size &&
                   output_block_size == rhs.output_block_size &&
                   tile_size == rhs.tile_size;
        }
        std::string to_string() const;
    };
    //! sentinel returned by parse_winograd_name() when the name is unmatched
    static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
    struct DirectParam {
        std::string to_string() const { return ""; }
    };
    struct MatmulParam {
        std::string to_string() const { return ""; }
    };
    struct DefaultParam {
        std::string to_string() const { return ""; }
    };
    //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
    //! \warning: base must not contain :.
    template <typename T>
    static std::string algo_name(
            const std::string& base, const T& p,
            param::ConvBias::Format format = param::ConvBias::Format::NCHW);
    /*!
     * \brief parse algo_name and get WinogradParam from algo name.
     *
     * \param algo name string
     * \return WinogradParam parsed from algo name, using the pattern
     * winograd:base:m:tile_size.
     *
     * \warning: INVALID_WINOGRAD_PARAM is returned if the algo_name is not
     * matched.
     */
    static WinogradParam parse_winograd_name(const std::string& algo_name);
    /**
     * @brief find if there is a nchw_nchwxx conv kernel optimized for the
     * argument; nchw44 is used for arm, nchw88 is used for x86
     *
     * @param src_dtype conv feature map data type
     * @param filter_dtype conv filter or weight data type
     * @param dst_dtype output data type
     * @param fm filter meta param
     * @param bias_mode bias mode, no_bias or broadcast or bias
     * @param nonline_mode identity or relu or h_swish or sigmoid
     * @return true, found a kernel
     * @return false, cannot find any kernel
     */
    static bool is_nchw_nchwxx_optimized(
            const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
            const DTypeEnum dst_dtype,
            const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
            const ConvBiasForward::BiasMode bias_mode,
            const param::ConvBias::NonlineMode nonline_mode);

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using ConvBias = ConvBiasForward;
/**
 * \brief base class for Conv - Nonline - Pooling
 */
class ConvPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
    /**
     * \ Param::Method: Two methods to fetch the input data.
     * The default method is WITH_TEXTURE_OBJ.
     * If you want to use WITH_SHARED_MEM mode,
     * please make sure that the size of
     * [ all of the filter kernels + a channel
     * of input data + a channel of output data]
     * is no larger than 38KB.
     * And the pooling mode should not be "MAX".
     */
    DEF_OPR_PARAM(ConvPooling);

protected:
    //! deduce the output layout from src/filter/bias layouts
    virtual void deduce_layout(const TensorLayout& src,
                               const TensorLayout& filter,
                               const TensorLayout& bias, TensorLayout& dst) = 0;
    //! check layouts against \p workspace_limit_in_bytes
    virtual void check_layout(const TensorLayout& src,
                              const TensorLayout& filter,
                              const TensorLayout& bias, TensorLayout& dst,
                              size_t workspace_limit_in_bytes) = 0;
};
//! forward pass of the fused Conv - Nonline - Pooling operator
class ConvPoolingForward : public ConvPoolingBase {
    DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter convolution filter tensor
     * \param[in] bias bias tensor
     * \param[out] dst output tensor
     */
    virtual void exec(const _megdnn_in TensorND src,
                      const _megdnn_in TensorND filter,
                      const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
                      _megdnn_out Workspace workspace) = 0;
    virtual void deduce_layout(const TensorLayout& src,
                               const TensorLayout& filter,
                               const TensorLayout& bias, TensorLayout& dst) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& dst) = 0;

protected:
    virtual void check_layout(const TensorLayout& src,
                              const TensorLayout& filter,
                              const TensorLayout& bias, TensorLayout& dst,
                              size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;
//! base class for grouped local (untied-weight) convolution operators
class GroupLocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce the layout of \p dst from src/filter layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! check that the given layouts form a valid forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
//! forward pass of grouped local convolution
class GroupLocalForward : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
     * \param[out] dst (N, OC, OH, OW)
     **/
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layout of \p dst; thin wrapper over deduce_layout_fwd()
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst) {
        deduce_layout_fwd(src, filter, dst);
    }
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;
//! gradient of grouped local convolution wrt. the input data
class GroupLocalBackwardData : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] filter the forward filter
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
  556. class GroupLocalBackwardFilter : public GroupLocalBase {
  557. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  558. public:
  559. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  560. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  561. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  562. const TensorLayout& diff,
  563. const TensorLayout& grad) = 0;
  564. protected:
  565. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  566. const TensorLayout& grad, size_t workspace_in_bytes);
  567. };
  568. class Images2NeibsBase : public OperatorBase {
  569. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  570. DEF_OPR_PARAM(Images2Neibs);
  571. protected:
  572. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  573. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  574. };
//! extracts sliding-window neighbourhoods from an image tensor
class Images2NeibsForward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW)
     * \param[out] dst (N, C, OH, OW, window_h, window_w)
     *
     * \see
     * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
     *
     * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
     * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;
    //! deduce the layout of \p dst from \p src
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
//! gradient of Images2NeibsForward
class Images2NeibsBackward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for Pooling
 */
class PoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
    DEF_OPR_PARAM(Pooling);

public:
    //! pooling mode (e.g. MAX / AVERAGE; see param definition)
    using Mode = Param::Mode;

protected:
    //! deduce dst layout from src for the forward pass
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    //! validate src/dst layouts for the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
class PoolingForward : public PoolingBase {
    DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);

public:
    /**
     * \brief apply pooling as configured by the operator param
     *
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src layout and the operator param
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
class PoolingBackward : public PoolingBase {
    DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);

public:
    /**
     * \brief gradient of PoolingForward wrt. its input
     *
     * \param[in] src the `src' parameter in PoolingForward::exec
     * \param[in] dst the `dst' parameter in PoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for AdaptivePooling
 */
class AdaptivePoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
    DEF_OPR_PARAM(AdaptivePooling);

protected:
    //! derive an equivalent plain-Pooling param (window/stride/padding)
    //! from the given src/dst layouts
    param::Pooling deduce_pooling_param(const TensorLayout& src,
                                        const TensorLayout& dst);
};
class AdaptivePoolingForward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);

public:
    /**
     * \brief pooling whose window is derived from the dst spatial shape
     *
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;
};
using AdaptivePooling = AdaptivePoolingForward;
class AdaptivePoolingBackward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);

public:
    /**
     * \brief gradient of AdaptivePoolingForward wrt. its input
     *
     * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
     * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
};
/**
 * \brief base class for Local (locally-connected, i.e. unshared convolution)
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    // NOTE: Local reuses the Convolution param definition
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce dst layout from src/filter for the forward pass
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate src/filter/dst layouts for the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \brief locally-connected forward: an independent filter per output
     * position
     *
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \brief gradient of LocalForward wrt. the input data
     *
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \brief gradient of LocalForward wrt. the filter
     *
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for batch normalization operators
 */
class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    //! validate the operator param (no tensor layouts involved)
    void check_param();
};
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 5);

public:
    /**
     * \brief batch normalization forward:
     * dst[i] = gamma
     * *(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + beta, where
     * epsilon is a very small value to avoid a "divide by zero" error.
     * \param[in] src (n, c, h, w)
     * \param[in] bn_scale (see m_param.ParamDim) scale (gamma) tensor
     * \param[in] bn_bias (see m_param.ParamDim) bias (beta) tensor
     * \param[in,out] mean (see m_param.ParamDim) Global mean.
     * \param[in,out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *     Optionally cached intermediate mean from forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *     Optionally cached intermediate variance from forward pass
     * \param[out] dst (n, c, h, w)
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                      _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
                      _megdnn_tensor_inout variance,
                      _megdnn_tensor_out batch_mean,
                      _megdnn_tensor_out batch_inv_variance,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce all non-src layouts from the src layout and operator param
    void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
                       TensorLayout& bn_bias, TensorLayout& mean,
                       TensorLayout& variance, TensorLayout& batch_mean,
                       TensorLayout& batch_inv_variance, TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
                    const TensorLayout& bn_bias, const TensorLayout& mean,
                    const TensorLayout& variance,
                    const TensorLayout& batch_mean,
                    const TensorLayout& batch_inv_variance,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using BN = BNForward;
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);

public:
    /**
     * \brief gradient of BNForward wrt. input, scale and bias
     *
     * \param[in] x input data of the forward pass.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale, the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias, the backpropagated gradient of bn_bias.
     * Optionally cached intermediate results from forward pass
     * \param[in] saved_batch_mean mean of the input batch.
     *     Calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch.
     *     Calculated in the forward propagation.
     */
    virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                      _megdnn_tensor_in saved_batch_mean,
                      _megdnn_tensor_in saved_batch_variance,
                      _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
                      _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance,
            const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;

protected:
    void check_exec(const TensorLayout& x, const TensorLayout& dy,
                    const TensorLayout& saved_batch_mean,
                    const TensorLayout& saved_batch_variance,
                    const TensorLayout& bn_scale,
                    const TensorLayout& d_bn_scale,
                    const TensorLayout& d_bn_bias, const TensorLayout& dx,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for local response normalization operators
 */
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! validate the operator param (no tensor layouts involved)
    void check_param();
};
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \brief local response normalization across channels
     *
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src layout (same shape)
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using LRN = LRNForward;
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \brief gradient of LRNForward wrt. its input
     *
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for ROI pooling operators
 */
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate src/rois/dst/index layouts for the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \brief pool features within each region of interest
     *
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn .
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \brief gradient of ROIPoolingForward wrt. its input
     *
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
                      _megdnn_tensor_in rois, _megdnn_tensor_in index,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& src,
                    const TensorLayout& rois, const TensorLayout& index,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for 3D convolution operators
 */
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;
    //! canonical description of a filter, independent of its raw layout
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! deduce dst layout from src/filter; also returns the canonized filter
    CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          TensorLayout& dst) const;
    //! validate src/filter/dst layouts; also returns the canonized filter
    CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         const TensorLayout& dst) const;
    //! build the canonical filter description from a raw filter layout
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
class Convolution3DForward
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \brief 3D convolution forward
     *
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src/filter layouts and the operator param
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \brief gradient of Convolution3DForward wrt. the input data
     *
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce grad layout from filter/diff layouts and the operator param
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \brief gradient of Convolution3DForward wrt. the filter
     *
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief base class for local-share convolution (filters shared within
 * spatial groups)
 */
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! deduce dst layout from src/filter for the forward pass
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate src/filter/dst layouts for the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \brief local-share convolution forward
     *
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \brief gradient of LocalShareForward wrt. the input data
     *
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce grad layout from filter/diff layouts and the operator param
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \brief gradient of LocalShareForward wrt. the filter
     *
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for ROI align operators
 */
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! deduce dst/index layouts from src/rois for the forward pass
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    //! validate src/rois/dst/index layouts for the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \brief pool features within each region of interest using bilinear
     * sampling
     *
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    //! deduce dst/index layouts from src/rois layouts and the operator param
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \brief gradient of ROIAlignForward wrt. its input
     *
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for deformable convolution operators
 */
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    // NOTE: deformable conv reuses the Convolution param definition
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    //! Convolution filter meta extended with the deformable-group count
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };

protected:
    //! build the canonical filter description from raw filter/offset layouts
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    // NOTE(review): these take (mask, offset) while the forward exec takes
    // (offset, mask) — verify call sites pass arguments in the right order
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \brief deformable convolution forward (dg = number of deformable groups)
     *
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce dst layout from the input layouts and the operator param
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     *
     * (shape docs corrected to match DeformableConvForward: `im' is the input
     * image, `filter_grad' has the filter's shape)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    //! deduce filter_grad layout from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     *
     * (`im' shape doc corrected to match DeformableConvForward)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace (in bytes) required by exec for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    //! deduce all gradient layouts from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Common base of the deformable PS-ROI pooling operators: holds the
 * shared operator parameter and the forward layout deduction / validation
 * helpers used by both the forward and backward operators.
 */
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! Deduce the out_data / out_count layouts for the forward pass.
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    //! Validate the forward layouts and the provided workspace size.
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    // 3 inputs (data, rois, trans), 2 outputs (out_data, out_count).
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * \brief Forward deformable position-sensitive ROI pooling.
     *
     * NOTE(review): the shape annotations below are copied from the upstream
     * comment; several entries are literal "xx" placeholders and the others
     * look like copy-paste slips — verify all of them against
     * deduce_layout_fwd() in the implementation.
     *
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data (n, ic, ih, iw) pooled output
     * \param[out] out_count (n, ic, ih, iw) per-bin sample counts
     */
    //! Minimum scratch-workspace size (in bytes) required by exec().
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    //! Deduce the output layouts from the input layouts.
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    //! Validate the layouts and workspace size before exec().
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    // 5 inputs (data, rois, trans, out_diff, out_count),
    // 2 outputs (data_diff, trans_diff).
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * \brief Backward deformable position-sensitive ROI pooling: gradients
     * w.r.t. the input feature map and the transformation (offsets).
     *
     * NOTE(review): the shape annotations below are copied from the upstream
     * comment; several are literal "xx" placeholders, and \p trans_diff
     * should presumably match the shape of \p trans rather than
     * (n, ic, ih, iw) — verify against the implementation.
     *
     * \param[in] data (oc, ic, ih, iw) forward input feature map
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx) gradient of the forward output
     * \param[in] out_count (xx, xx, xx, xx) per-bin counts from the forward pass
     * \param[out] data_diff (n, ic, ih, iw) gradient w.r.t. \p data
     * \param[out] trans_diff (n, ic, ih, iw) gradient w.r.t. \p trans
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    //! Minimum scratch-workspace size (in bytes) required by exec().
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    //! Validate the layouts and workspace size before exec().
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
/**
 * \brief Batched convolution fused with bias and an extra input \p z.
 *
 * NOTE(review): from the visible interface this appears to compute a
 * convolution of \p src with \p filter, add \p bias and \p z, and write the
 * result to \p dst (shapes and fusion semantics governed by
 * param::BatchConvBias) — confirm against the operator implementation.
 */
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    // 4 inputs (src, filter, bias, z), 1 output (dst).
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! Deduce the output dtype from the input dtypes.
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! Deduce the output layout from the input layouts.
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    //! Minimum scratch-workspace size (in bytes) required by exec().
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate the layouts and workspace size before exec().
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
/**
 * \brief Common base of the FakeQuant operators: holds the shared parameter
 * and the layout deduction / validation helpers.
 */
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);

protected:
    //! The output layout is deduced from the input layout alone.
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    //! Validate the input/scale/zero_point/output layouts.
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
                          const TensorLayout& zero_point,
                          const TensorLayout& output);
};
/**
 * \brief Forward fake (simulated) quantization.
 *
 * NOTE(review): judging from the interface (input, per-tensor/per-channel
 * scale and zero_point, same-shaped output), this presumably
 * quantizes-then-dequantizes \p input for quantization-aware training —
 * confirm against the implementation.
 */
class FakeQuantForward : public FakeQuantBase {
    // 3 inputs (input, scale, zero_point), 1 output (output).
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);

public:
    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                      _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
                      _megdnn_workspace workspace) = 0;
    //! Deduce the output layout from the input layouts.
    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
                       const TensorLayout& zero_point, TensorLayout& output);
    //! Minimum scratch-workspace size (in bytes) required by exec().
    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& output) = 0;

protected:
    //! Validate the layouts and workspace size before exec().
    void check_exec(const TensorLayout& input, const TensorLayout& scale,
                    const TensorLayout& zero_point, const TensorLayout& output,
                    size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
/**
 * \brief Backward pass of fake quantization: propagate \p diff (the gradient
 * of the output) back to \p grad (the gradient w.r.t. \p input).
 */
class FakeQuantBackward : public FakeQuantBase {
    // 4 inputs (diff, input, scale, zero_point), 1 output (grad).
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                      _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! Minimum scratch-workspace size (in bytes) required by exec().
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate the layouts and workspace size before exec().
    void check_exec(const TensorLayout& diff, const TensorLayout& input,
                    const TensorLayout& scale, const TensorLayout& zero_point,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
  1481. } // namespace megdnn
  1482. #include "megdnn/internal/opr_header_epilogue.h"
  1483. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,因此无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序,请确保机器配有 GPU 硬件设备并已安装好相应驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发,欢迎访问 MegStudio 平台。