You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 97 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. class SeparableConvBase : public OperatorBase {
  5. DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
  6. DEF_OPR_PARAM(SeparableConv);
  7. public:
  8. using Mode = Param::Mode;
  9. protected:
  10. void deduce_layout_fwd(
  11. const TensorLayout& src, const TensorLayout& filter_x,
  12. const TensorLayout& filter_y, TensorLayout& dst);
  13. void check_layout_fwd(
  14. const TensorLayout& src, const TensorLayout& filter_x,
  15. const TensorLayout& filter_y, const TensorLayout& dst);
  16. };
  17. class SeparableConvForward : public SeparableConvBase {
  18. DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
  19. public:
  20. virtual void exec(
  21. _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  22. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  23. _megdnn_workspace workspace) = 0;
  24. void deduce_layout(
  25. const TensorLayout& src, const TensorLayout& filter_x,
  26. const TensorLayout& filter_y, TensorLayout& dst);
  27. virtual size_t get_workspace_in_bytes(
  28. const TensorLayout& src, const TensorLayout& filter_x,
  29. const TensorLayout& filter_y, const TensorLayout& dst) = 0;
  30. protected:
  31. void check_exec(
  32. const TensorLayout& src, const TensorLayout& filter_x,
  33. const TensorLayout& filter_y, const TensorLayout& dst,
  34. size_t workspace_in_bytes);
  35. };
  36. using SeparableConv = SeparableConvForward;
  37. namespace detail {
  38. struct PreprocessedFilter {
  39. //! user data; its lifetime should be bound to MegDNN Convolution
  40. //! operator
  41. void* algorithm_id;
  42. TensorNDArray tensors;
  43. };
  44. } // namespace detail
  45. /**
  46. * \brief base class for convolution operation
  47. *
  48. * This operator is supposed to perform convolution on arbitrary input
  49. * dimensions. The input/output format is N, C, dims..., and kernel format can
  50. * take two forms:
  51. * 1. OC, IC, dims..., for conventional dense convolution
  52. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  53. *
  54. * Currently, only 2D images are supported.
  55. */
  56. template <typename Parameter>
  57. class ConvolutionBase : public OperatorBase {
  58. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  59. using Param = Parameter;
  60. public:
  61. Param& param() { return m_param; }
  62. const Param& param() const { return m_param; }
  63. protected:
  64. Param m_param;
  65. public:
  66. static constexpr size_t MAX_SPATIAL_DIM = 2;
  67. using Mode = typename Param::Mode;
  68. struct CanonizedFilterMeta {
  69. DType dtype;
  70. typename Param::Format format;
  71. uint32_t
  72. //! whether filter should be flipped (i.e. is CONVOLUTION)
  73. should_flip,
  74. group, //!< number of groups
  75. icpg, //!< input channels per group
  76. ocpg, //!< output channels per group
  77. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  78. //! spatial dim
  79. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  80. //! spatial dim with dilation applied
  81. dilated_spatial[MAX_SPATIAL_DIM];
  82. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  83. template <typename T>
  84. void copy_from(const T& b) {
  85. dtype = b.dtype;
  86. format = b.format;
  87. should_flip = b.should_flip;
  88. group = b.group;
  89. icpg = b.icpg;
  90. ocpg = b.ocpg;
  91. spatial_ndim = b.spatial_ndim;
  92. memcpy(stride, b.stride, sizeof(stride));
  93. memcpy(padding, b.padding, sizeof(padding));
  94. memcpy(spatial, b.spatial, sizeof(spatial));
  95. memcpy(dilation, b.dilation, sizeof(dilation));
  96. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  97. }
  98. bool operator==(const CanonizedFilterMeta& b) const {
  99. bool flag = true;
  100. flag = flag && (format == b.format);
  101. flag = flag && (dtype == b.dtype);
  102. flag = flag && (should_flip == b.should_flip);
  103. flag = flag && (group == b.group);
  104. flag = flag && (icpg == b.icpg);
  105. flag = flag && (ocpg == b.ocpg);
  106. flag = flag && (spatial_ndim == b.spatial_ndim);
  107. if (flag) {
  108. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  109. flag = flag && (stride[i] == b.stride[i]);
  110. flag = flag && (padding[i] == b.padding[i]);
  111. flag = flag && (spatial[i] == b.spatial[i]);
  112. flag = flag && (dilation[i] == b.dilation[i]);
  113. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  114. }
  115. }
  116. return flag;
  117. }
  118. };
  119. using PreprocessedFilter = detail::PreprocessedFilter;
  120. protected:
  121. // Check or deduce output DType
  122. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  123. CanonizedFilterMeta deduce_layout_fwd(
  124. const TensorLayout& src, const TensorLayout& filter,
  125. TensorLayout& dst) const;
  126. CanonizedFilterMeta check_layout_fwd(
  127. const TensorLayout& src, const TensorLayout& filter,
  128. const TensorLayout& dst) const;
  129. CanonizedFilterMeta make_canonized_filter_meta(
  130. size_t src_ndim, const TensorLayout& filter) const;
  131. };
  132. class MaskPropagate : public OperatorBase {
  133. DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
  134. DEF_OPR_PARAM(MaskPropagate);
  135. public:
  136. virtual void exec(
  137. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  138. _megdnn_workspace workspace) = 0;
  139. virtual size_t get_workspace_in_bytes(
  140. const TensorLayout& src, const TensorLayout& dst) = 0;
  141. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  142. };
  143. /**
  144. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  145. */
  146. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  147. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  148. public:
  149. virtual void exec(
  150. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask,
  151. _megdnn_tensor_out dst, _megdnn_workspace worksapce) = 0;
  152. virtual size_t get_workspace_in_bytes(
  153. const TensorLayout& src, const TensorLayout& filter,
  154. const TensorLayout& mask, const TensorLayout& dst) = 0;
  155. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  156. void deduce_layout(
  157. const TensorLayout& src, const TensorLayout& filter,
  158. const TensorLayout& mask, TensorLayout& dst);
  159. protected:
  160. CanonizedFilterMeta check_exec(
  161. const TensorLayout& src, const TensorLayout& filter,
  162. const TensorLayout& mask, const TensorLayout& dst,
  163. size_t workspace_in_bytes);
  164. };
  165. using MaskConvolution = MaskConvForward;
  166. /**
  167. * \brief ConvolutionForward operator.
  168. */
  169. class ConvolutionForward : public ConvolutionBase<param::Convolution>,
  170. public detail::MultiAlgoOpr<ConvolutionForward, 3> {
  171. DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
  172. public:
  173. /**
  174. * \param[in] src (n, ic, ih, iw)
  175. * \param[in] filter (oc, ic, fh, fw)
  176. * \param[in] preprocessed_filter if weight no preprocessed it will be
  177. * nullptr, else the preprocessed weights store in the tensors of
  178. * preprocessed_filter.
  179. * \param[in] workspace if weight no preprocessed
  180. * (preprocessed_filter == nullptr), The size of the workspace satisfies the
  181. * situation that weights is not processed, other wise the size of workspace
  182. * satisfies the situation that weights is preprocessed
  183. * \param[out] dst (n, oc, oh, ow)
  184. */
  185. virtual void exec(
  186. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
  187. const PreprocessedFilter* preprocessed_filter,
  188. _megdnn_workspace workspace) = 0;
  189. MGE_WIN_DECLSPEC_FUC void exec(
  190. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
  191. _megdnn_workspace workspace) {
  192. exec(src, filter, dst, nullptr, workspace);
  193. }
  194. /**
  195. * \brief execute weight preprocessing, read weights form filter and write
  196. * to preprocessed_filter after preprocessed.
  197. *
  198. * \praram[in] workspace the needed tmp workspace when exec_preprocess
  199. */
  200. virtual void exec_preprocess(
  201. const TensorLayout& src_layout, _megdnn_tensor_in filter,
  202. const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
  203. _megdnn_workspace workspace) = 0;
  204. MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType src, DType filter, DType& dst);
  205. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  206. const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
  207. /**
  208. * \brief query the workspace needed when executing the opr, if the weights
  209. * are preprocessed the preprocessed_filter will not be nullptr, else it
  210. * will be nullptr, the workspace size maybe different whether weights are
  211. * preprocessed
  212. *
  213. * \return the size of workspace needed when executing
  214. */
  215. virtual size_t get_workspace_in_bytes(
  216. const TensorLayout& src, const TensorLayout& filter,
  217. const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0;
  218. /**
  219. * \brief deduce the preprocessed filter layouts according to the src,
  220. * filter and dst layout, the result may contain multi layouts when the
  221. * weights is not one
  222. *
  223. * \return SmallVector<TensorLayout> Derive the layouts of weight
  224. * preprocessing, return empty if preprocessing is not needed.
  225. */
  226. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  227. const TensorLayout& src, const TensorLayout& filter,
  228. const TensorLayout& dst) = 0;
  229. /**
  230. * \brief query the workspace needed when preprocessing the weights,
  231. * according to the return size, a _megdnn_workspace will be created and
  232. * passed through exec_preprocess
  233. *
  234. * \return the size of workspace needed when preprocessing
  235. */
  236. virtual size_t get_preprocess_workspace_in_bytes(
  237. const TensorLayout& src, const TensorLayout& filter,
  238. const TensorLayout& dst) = 0;
  239. static Algorithm::OprType get_opr_type() {
  240. return Algorithm::OprType::CONVOLUTION_FORWARD;
  241. }
  242. protected:
  243. CanonizedFilterMeta check_exec(
  244. const TensorLayout& src, const TensorLayout& filter,
  245. const TensorLayout& dst, size_t workspace_in_bytes,
  246. const PreprocessedFilter* preprocessed_filter);
  247. };
  248. using Convolution = ConvolutionForward;
  249. /**
  250. * \brief ConvolutionBackwardData operator.
  251. *
  252. * Calculating the gradient wrt. convolution input data.
  253. */
  254. class ConvolutionBackwardData
  255. : public ConvolutionBase<param::Convolution>,
  256. public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
  257. DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
  258. public:
  259. /**
  260. * \param[in] filter (oc, ic, fh, fw)
  261. * \param[in] diff (n, oc, oh, ow)
  262. * \param[out] grad (n, ic, ih, iw)
  263. */
  264. virtual void exec(
  265. _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  266. _megdnn_workspace workspace) = 0;
  267. virtual size_t get_workspace_in_bytes(
  268. const TensorLayout& filter, const TensorLayout& diff,
  269. const TensorLayout& grad) = 0;
  270. MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType filter, DType diff, DType& grad);
  271. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  272. const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
  273. static Algorithm::OprType get_opr_type() {
  274. return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
  275. }
  276. protected:
  277. CanonizedFilterMeta check_exec(
  278. const TensorLayout& filter, const TensorLayout& diff,
  279. const TensorLayout& grad, size_t workspace_in_bytes);
  280. };
  281. /**
  282. * \brief ConvolutionBackwardFilter operator.
  283. *
  284. * Calculating the gradient wrt. convolution filter.
  285. */
  286. class ConvolutionBackwardFilter
  287. : public ConvolutionBase<param::Convolution>,
  288. public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
  289. DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
  290. public:
  291. /**
  292. * \param[in] src (n, ic, ih, iw)
  293. * \param[in] diff (n, oc, oh, ow)
  294. * \param[out] grad (oc, ic, fh, fw)
  295. */
  296. virtual void exec(
  297. _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  298. _megdnn_workspace workspace) = 0;
  299. virtual size_t get_workspace_in_bytes(
  300. const TensorLayout& src, const TensorLayout& diff,
  301. const TensorLayout& grad) = 0;
  302. static Algorithm::OprType get_opr_type() {
  303. return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
  304. }
  305. protected:
  306. CanonizedFilterMeta check_exec(
  307. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
  308. size_t workspace_in_bytes);
  309. };
  310. /**
  311. * \brief ConvolutionBias operator
  312. */
  313. class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
  314. public detail::MultiAlgoOpr<ConvBiasForward, 5> {
  315. DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
  316. public:
  317. /**
  318. * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
  319. * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
  320. * 4 * ic)
  321. * \param[in] bias (1, oc, 1, 1)
  322. * \param[in] z same as dst
  323. * \param[in] preprocessed_filter if weight no preprocessed it will be
  324. * nullptr, else the preprocessed weights store in the tensors of
  325. * preprocessed_filter.
  326. * \param[in] workspace if weight no preprocessed
  327. * (preprocessed_filter == nullptr), The size of the workspace satisfies the
  328. * situation that weights is not processed, other wise the size of workspace
  329. * satisfies the situation that weights is preprocessed
  330. * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
  331. *
  332. * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
  333. * alphaw, oc, ic)
  334. */
  335. virtual void exec(
  336. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
  337. _megdnn_tensor_in z, _megdnn_tensor_out dst,
  338. const PreprocessedFilter* preprocessed_filter,
  339. _megdnn_workspace workspace) = 0;
  340. MGE_WIN_DECLSPEC_FUC void exec(
  341. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
  342. _megdnn_tensor_in z, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
  343. exec(src, filter, bias, z, dst, nullptr, workspace);
  344. }
  345. /**
  346. * \brief execute weight preprocessing, read weights form filter and bias,
  347. * write to preprocessed_filter after preprocessed.
  348. *
  349. * \praram[in] workspace the needed tmp workspace when exec_preprocess
  350. * running, the size is got by get_preprocess_workspace_in_bytes
  351. */
  352. virtual void exec_preprocess(
  353. const TensorLayout& src_layout, _megdnn_tensor_in filter,
  354. _megdnn_tensor_in bias, const TensorLayout& z_layout,
  355. const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
  356. _megdnn_workspace workspace) = 0;
  357. MGE_WIN_DECLSPEC_FUC void deduce_dtype(
  358. DType src, DType filter, DType bias, DType z, DType& dst);
  359. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  360. const TensorLayout& src, const TensorLayout& filter,
  361. const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
  362. /**
  363. * \brief query the workspace needed when executing the opr, if the weights
  364. * are preprocessed the preprocessed_filter will not be nullptr, else it
  365. * will be nullptr, the workspace size maybe different whether weights are
  366. * preprocessed
  367. *
  368. * \return the size of workspace needed when executing
  369. */
  370. virtual size_t get_workspace_in_bytes(
  371. const TensorLayout& src, const TensorLayout& filter,
  372. const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
  373. const PreprocessedFilter* preprocessed_filter) = 0;
  374. /**
  375. * \brief query the workspace needed when pre-processing the weights,
  376. * according to the return size, a _megdnn_workspace will be created and
  377. * passed through exec_preprocess
  378. *
  379. * \return the size of workspace needed when pre-processing
  380. */
  381. virtual size_t get_preprocess_workspace_in_bytes(
  382. const TensorLayout& src, const TensorLayout& filter,
  383. const TensorLayout& bias, const TensorLayout& z,
  384. const TensorLayout& dst) = 0;
  385. /**
  386. * \brief deduce the pre-processed filter layouts according to the src,
  387. * filter and dst layout, which may contain multi layouts when the weights
  388. * is not one
  389. *
  390. * \return SmallVector<TensorLayout> Derive the layouts of weight
  391. * preprocessing, return empty if preprocessing is not needed.
  392. */
  393. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  394. const TensorLayout& src, const TensorLayout& filter,
  395. const TensorLayout& bias, const TensorLayout& z,
  396. const TensorLayout& dst) = 0;
  397. enum class BiasMode : uint32_t {
  398. NO_BIAS = 0, //!< no bias
  399. BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
  400. BIAS //!< [N, C, H, W]
  401. };
  402. //! param for winograd algos.
  403. struct WinogradParam {
  404. uint32_t channel_block_size;
  405. uint32_t output_block_size;
  406. uint32_t tile_size;
  407. uint32_t filter_size;
  408. bool operator==(const WinogradParam& rhs) const {
  409. return channel_block_size == rhs.channel_block_size &&
  410. output_block_size == rhs.output_block_size &&
  411. tile_size == rhs.tile_size && filter_size == rhs.filter_size;
  412. }
  413. std::string to_string() const;
  414. };
  415. static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0, 0};
  416. struct DirectParam {
  417. std::string to_string() const { return ""; }
  418. };
  419. struct MatmulParam {
  420. std::string to_string() const { return ""; }
  421. };
  422. struct DefaultParam {
  423. std::string to_string() const { return ""; }
  424. };
  425. //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
  426. //! \warning: base must not contain :.
  427. template <typename T>
  428. static std::string algo_name(
  429. const std::string& base, const T& p,
  430. param::ConvBias::Format format = param::ConvBias::Format::NCHW);
  431. /*!
  432. * \brief parse algo_name and get WinogradParam from algo name.
  433. *
  434. * \param algo name string
  435. * \return WinogradParam parsed from algo name, use pattern
  436. * winograd:base:m:tile_size.
  437. *
  438. * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
  439. */
  440. static WinogradParam parse_winograd_name(const std::string& algo_name);
  441. /**
  442. * @brief find if there is nchw_nchwxx conv kernel optimized for argment,
  443. * nchw44 used for arm, nchw88 used for x86
  444. *
  445. * @param src_dtype conv feature map data type
  446. * @param filter_dtype conv filter or weight data type
  447. * @param dst_dtype output data type
  448. * @param fm filter meta param
  449. * @param bias_mode bias mode, no_bias or broadcast or bias
  450. * @param nonline_mode identity or relu or h_swish or sigmoid
  451. * @return true, found a kernel
  452. * @return false, can`t found any kernel
  453. */
  454. static bool is_nchw_nchwxx_optimized(
  455. const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
  456. const DTypeEnum dst_dtype,
  457. const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
  458. const ConvBiasForward::BiasMode bias_mode,
  459. const param::ConvBias::NonlineMode nonline_mode);
  460. static Algorithm::OprType get_opr_type() {
  461. return Algorithm::OprType::CONVBIAS_FORWARD;
  462. }
  463. protected:
  464. CanonizedFilterMeta check_exec(
  465. const TensorLayout& src, const TensorLayout& filter,
  466. const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
  467. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
  468. CanonizedFilterMeta check_exec_allow_noncontiguous(
  469. const TensorLayout& src, const TensorLayout& filter,
  470. const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
  471. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
  472. };
  473. using ConvBias = ConvBiasForward;
  474. /**
  475. * \brief RegionRestrictedConvolutionForward operator.
  476. */
  477. class RegionRestrictedConvolutionForward : public ConvolutionBase<param::Convolution> {
  478. DEF_OPR_IMPL(RegionRestrictedConvolutionForward, ConvolutionBase, 4, 1);
  479. public:
  480. /**
  481. * \param[in] src (n, ic, ih, iw) or (n, g*icpg, ih, iw)
  482. * \param[in] filter (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
  483. * \param[in] rin (n, ih, iw)
  484. * \param[in] rout (n, oh, ow)
  485. * \param[out] dst (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
  486. */
  487. virtual void exec(
  488. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in rin,
  489. _megdnn_tensor_in rout, _megdnn_tensor_out dst,
  490. _megdnn_workspace workspace) = 0;
  491. void deduce_dtype(DType src, DType filter, DType rin, DType rout, DType& dst);
  492. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  493. const TensorLayout& src, const TensorLayout& filter,
  494. const TensorLayout& rin, const TensorLayout& rout, TensorLayout& dst);
  495. /**
  496. * \brief query the workspace needed when executing the opr
  497. * \return the size of workspace needed when executing
  498. */
  499. virtual size_t get_workspace_in_bytes(
  500. const TensorLayout& src, const TensorLayout& filter,
  501. const TensorLayout& rin, const TensorLayout& rout,
  502. const TensorLayout& dst) = 0;
  503. static Algorithm::OprType get_opr_type() {
  504. return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_FORWARD;
  505. }
  506. protected:
  507. CanonizedFilterMeta check_exec(
  508. const TensorLayout& src, const TensorLayout& filter,
  509. const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& dst,
  510. size_t workspace_in_bytes);
  511. };
  512. using RegionRestrictedConvolution = RegionRestrictedConvolutionForward;
/**
 * \brief RegionRestrictedConvolutionBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data.
 */
class RegionRestrictedConvolutionBackwardData
        : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardData, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
     * \param[in] diff (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
     * \param[in] rin (n, ih, iw)
     * \param[in] rout (n, oh, ow)
     * \param[out] grad (n, ic, ih, iw) or (n, g*icpg, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout,
            const TensorLayout& grad) = 0;
    //! deduce the dtype of \p grad from the input dtypes
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(
            DType filter, DType diff, DType rin, DType rout, DType& grad);
    //! deduce the layout of \p grad from the input layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size before exec()
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& rin, const TensorLayout& rout, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief RegionRestrictedConvolutionBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class RegionRestrictedConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(RegionRestrictedConvolutionBackwardFilter, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw) or (n, g*icpg, ih, iw)
     * \param[in] diff (n, oc, oh, ow) or (n, g*ocpg, oh, ow)
     * \param[in] rin (n, ih, iw)
     * \param[in] rout (n, oh, ow)
     * \param[out] grad (oc, ic, fh, fw) or (g, ocpg, icpg, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
            _megdnn_tensor_in rout, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
            const TensorLayout& rout, const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::REGIONRESTRICTEDCONVOLUTION_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size before exec()
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
            const TensorLayout& rout, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for Conv - Nonline - Pooling
 */
class ConvPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);

    /**
     * \brief Param::Method: two methods to fetch the input data.
     * The default method is WITH_TEXTURE_OBJ.
     * If you want to use WITH_SHARED_MEM mode,
     * please make sure that the total size of
     * [all of the filter kernels + one channel
     * of input data + one channel of output data]
     * is no larger than 38KB,
     * and the pooling mode is not "MAX".
     */
    DEF_OPR_PARAM(ConvPooling);

protected:
    //! deduce the output layout from the input layouts
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    //! validate the layouts against the workspace limit
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
//! fused convolution + nonlinearity + pooling forward operator
class ConvPoolingForward : public ConvPoolingBase {
    DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter convolution filter tensor
     * \param[in] bias bias tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            const _megdnn_in TensorND src, const _megdnn_in TensorND filter,
            const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
            _megdnn_out Workspace workspace) = 0;
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& dst) = 0;

protected:
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;
//! base class for group-local (locally-connected, grouped) operators
class GroupLocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce the forward output layout from src and filter layouts
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! validate forward layouts
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
//! group-local forward operator
class GroupLocalForward : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
     * \param[out] dst (N, OC, OH, OW)
     **/
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce \p dst layout; forwards to the base-class implementation
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
        deduce_layout_fwd(src, filter, dst);
    }
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;
//! gradient of GroupLocalForward wrt. the input data
class GroupLocalBackwardData : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] filter the `filter' parameter in GroupLocalForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
  682. class GroupLocalBackwardFilter : public GroupLocalBase {
  683. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  684. public:
  685. virtual void exec(
  686. _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  687. _megdnn_workspace workspace) = 0;
  688. virtual size_t get_workspace_in_bytes(
  689. const TensorLayout& src, const TensorLayout& diff,
  690. const TensorLayout& grad) = 0;
  691. protected:
  692. void check_exec(
  693. const TensorLayout& filter, const TensorLayout& diff,
  694. const TensorLayout& grad, size_t workspace_in_bytes);
  695. };
  696. class Images2NeibsBase : public OperatorBase {
  697. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  698. DEF_OPR_PARAM(Images2Neibs);
  699. protected:
  700. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  701. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  702. };
//! extract sliding-window neighbourhoods of an image into an explicit axis
class Images2NeibsForward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW)
     * \param[out] dst (N, C, OH, OW, window_h, window_w)
     *
     * \see
     * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
     *
     * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
     * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
     * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
//! gradient of Images2NeibsForward
class Images2NeibsBackward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
  746. class SlidingWindowTransposeBase : public OperatorBase {
  747. DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase);
  748. DEF_OPR_PARAM(SlidingWindowTranspose);
  749. protected:
  750. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  751. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  752. };
//! fold explicit sliding windows back into an image (inverse of Images2Neibs)
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW, window_h, window_w)
     * \param[out] dst (N, C, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using SlidingWindowTranspose = SlidingWindowTransposeForward;
//! gradient of SlidingWindowTransposeForward
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for Pooling
 */
class PoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
    DEF_OPR_PARAM(Pooling);

public:
    using Mode = Param::Mode;

protected:
    //! deduce the forward output layout from the input layout
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    //! validate forward layouts
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);

public:
    //! static layout deduction, usable without an operator instance
    static void deduce_layout_impl(
            const TensorLayout& src, const Param& param, TensorLayout& dst);
};
//! pooling forward operator with multi-algorithm selection support
class PoolingForward : public PoolingBase,
                       public detail::MultiAlgoOpr<PoolingForward, 2> {
    DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_FORWARD;
    }

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
//! gradient of PoolingForward with multi-algorithm selection support
class PoolingBackward : public PoolingBase,
                        public detail::MultiAlgoOpr<PoolingBackward, 4> {
    DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in PoolingForward::exec
     * \param[in] dst the `dst' parameter in PoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_BACKWARD;
    }

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for AdaptivePooling
 */
class AdaptivePoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
    DEF_OPR_PARAM(AdaptivePooling);

protected:
    //! derive an equivalent plain-Pooling param from the src/dst layouts
    param::Pooling deduce_pooling_param(
            const TensorLayout& src, const TensorLayout& dst);
};
//! adaptive pooling forward operator (output size fixed, window derived)
class AdaptivePoolingForward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
};
using AdaptivePooling = AdaptivePoolingForward;
//! gradient of AdaptivePoolingForward
class AdaptivePoolingBackward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
     * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
};
/**
 * \brief base class for Local
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce the forward output layout from src and filter layouts
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! validate forward layouts
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
//! locally-connected (untied-weights convolution) forward operator
class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
//! gradient of LocalForward wrt. the input data
class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
//! gradient of LocalForward wrt. the filter
class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
//! base class for batch normalization operators
class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    //! validate the operator param
    void check_param();
};
//! batch normalization forward operator
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 6);

public:
    /**
     * \dst[i] = gamma
     * *(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + beta \where
     * epsilon is a very small value to avoid a "divide by zero" error.
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[in,out] mean (see m_param.ParamDim) Global mean.
     * \param[in,out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *      Optionally cached intermediate mean from forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *      Optionally cached intermediate variance from forward pass
     * \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
            _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
            _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean,
            _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layouts of all outputs from src/bn_scale/bn_bias
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance,
            TensorLayout& batch_mean, TensorLayout& batch_inv_variance,
            TensorLayout& reserve, TensorLayout& dst);
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
            const TensorLayout& dst) = 0;
    //! size of the reserve tensor to be carried to the backward pass
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    //! validate layouts, workspace and reserve sizes before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& dst,
            size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
};
using BN = BNForward;
//! batch normalization backward operator
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);

public:
    /**
     * \param[in] x input data of the forward propagation.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale, the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias, the backpropagated gradient of bn_bias.
     * Optionally cached intermediate results from forward pass
     * \param[in] saved_batch_mean mean of the input batch.
     *      Calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch.
     *      Calculated in the forward propagation.
     * \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
     */
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in dy,
            _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance,
            _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
            _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
            _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& reserve, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;
    //! size of the reserve tensor expected from the forward pass
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    //! validate layouts, workspace and reserve sizes before exec()
    void check_exec(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
            const TensorLayout& dx, size_t workspace_in_bytes,
            size_t reserve_in_bytes = 0);
};
//! base class for local response normalization operators
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! validate the operator param
    void check_param();
};
//! local response normalization forward operator
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using LRN = LRNForward;
//! gradient of LRNForward
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
//! base class for ROI pooling operators
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate forward layouts
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};
//! ROI pooling forward operator
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
//! gradient of ROIPoolingForward
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in rois,
            _megdnn_tensor_in index, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
//! base class for 3D convolution operators
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;

    //! canonical (format-independent) description of a 3D conv filter
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! deduce the forward output layout; returns the canonized filter meta
    CanonizedFilterMeta deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            TensorLayout& dst) const;
    //! validate forward layouts; returns the canonized filter meta
    CanonizedFilterMeta check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) const;
    static CanonizedFilterMeta make_canonized_filter_meta_impl(
            size_t src_ndim, const TensorLayout& filter, const Param& param);
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
//! 3D convolution forward operator with multi-algorithm selection support
class Convolution3DForward : public Convolution3DBase,
                             public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! \brief query the workspace size needed to execute this opr
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_FORWARD;
    }

protected:
    //! validate layouts and workspace size before exec()
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;
//! gradient of Convolution3DForward wrt. the input data
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    //! static layout deduction, usable without an operator instance
    static void deduce_layout_impl(
            const TensorLayout& filter, const TensorLayout& diff, const Param& param,
            TensorLayout& grad);
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size before exec()
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Compute the gradient of Convolution3DForward wrt. its filter.
 */
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size prior to exec
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of the local-share convolution operators; holds the shared
 * param and the forward layout deduction/check helpers.
 */
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! deduce \p dst layout of the forward pass from \p src and \p filter
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! check that the three forward layouts are mutually consistent
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
/**
 * \brief Local-share convolution forward: filters are shared within spatial
 * groups rather than across the whole image.
 */
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_FORWARD;
    }

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
/**
 * \brief Compute the gradient of LocalShareForward wrt. its input data.
 */
class LocalShareBackwardData : public LocalShareBase,
                               public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! deduce the layout of \p grad from \p filter and \p diff
    void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Compute the gradient of LocalShareForward wrt. its filter.
 */
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of the ROIAlign operators; holds the shared param and the
 * forward layout deduction/check helpers.
 */
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! deduce \p dst and \p index layouts of the forward pass
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    //! check that the forward layouts are mutually consistent
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that the first column of rois denotes the input image index. We
     * store it as a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
/**
 * \brief Compute the gradient of ROIAlignForward wrt. src.
 */
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in rois, _megdnn_tensor_in index,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of the deformable convolution operators.
 *
 * Extends the canonized Convolution filter meta with the deformable group
 * count.
 */
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        //! number of deformable groups (dg in the shape docs below)
        uint32_t deformable_group;
    };

protected:
    //! canonize \p filter (and derive deformable_group from \p offset)
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    // NOTE(review): the input image is named `im' here but `src' in
    // check_layout_fwd, and mask/offset order differs from the Forward
    // operator's (offset, mask) -- keep the argument order straight at call
    // sites.
    void deduce_layout_fwd(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset,
            const TensorLayout& dst);
};
/**
 * \brief Deformable convolution forward: sampling locations are displaced by
 * the learned \p offset and weighted by \p mask.
 */
class DeformableConvForward : public DeformableConvBase,
                              public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
    }

protected:
    //! validate layouts and workspace size prior to exec
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     *
     * NOTE(review): the original doc listed im as (oc, ic, fh, fw) and
     * filter_grad as (oc, ic, ih, iw); im is the convolution input and
     * filter_grad matches the filter, so the shapes above are presumably the
     * intended ones -- confirm against the implementation.
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
            _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            TensorLayout& filter_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size prior to exec
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad, size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     *
     * NOTE(review): the original doc listed im as (oc, ic, fh, fw); im is the
     * convolution input and its gradient im_grad is (n, ic, ih, iw), so the
     * shape above is presumably the intended one -- confirm against the
     * implementation.
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
            _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad,
            _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, TensorLayout& im_grad,
            TensorLayout& offset_grad, TensorLayout& mask_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size prior to exec
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of the deformable PS-ROI pooling operators; holds the
 * shared param and forward layout helpers.
 */
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! deduce \p out_data and \p out_count layouts of the forward pass
    void deduce_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, TensorLayout& out_data, TensorLayout& out_count);
    //! check forward layouts; also takes the workspace size, unlike the other
    //! *Base::check_layout_fwd helpers in this file
    void check_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, 2);

public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data ( n, ic, ih, iw)
     * \param[out] out_count ( n, ic, ih, iw)
     *
     * NOTE(review): the shapes above look like copy-paste placeholders
     * (`xx' entries, `data' with an oc leading dim) -- verify against the
     * implementation before relying on them.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count) = 0;
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_out out_data, _megdnn_tensor_out out_count,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, TensorLayout& out_data, TensorLayout& out_count);
    //! public here (not protected, unlike most operators in this file)
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
/**
 * \brief Compute the gradients of DeformablePSROIPoolingForward wrt. data and
 * trans.
 */
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, 2);

public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx)
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff ( n, ic, ih, iw)
     * \param[out] trans_diff ( n, ic, ih, iw)
     *
     * NOTE(review): several shapes above are `xx' placeholders -- verify
     * against the implementation before relying on them.
     */
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count,
            _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff) = 0;
    //! public here (not protected, unlike most operators in this file)
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
/**
 * \brief Convolution with per-batch filters, fused with bias addition and an
 * extra input \p z (see param::BatchConvBias for the exact semantics).
 */
class BatchConvBiasForward : public ConvolutionBase<param::BatchConvBias>,
                             public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCH_CONV_FORWARD;
    }

protected:
    //! validate layouts and workspace size prior to exec
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
/**
 * \brief Base class of the fake-quantization operators; holds the shared
 * param and forward layout helpers.
 */
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);

protected:
    //! output layout follows the input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output);
};
/**
 * \brief Fake-quantize \p input with the given \p scale and \p zero_point
 * (quantize-then-dequantize, keeping the float dtype).
 */
class FakeQuantForward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
/**
 * \brief Compute the gradient of FakeQuantForward wrt. its input.
 */
class FakeQuantBackward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);

public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Base class of the TQT (trained-quantization-threshold) operators;
 * holds the shared param and forward layout helpers.
 */
class TQTBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase);
    DEF_OPR_PARAM(TQT);

protected:
    //! output layout follows the input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output);
};
/**
 * \brief TQT forward: fake-quantize \p input with a trainable \p scale.
 */
class TQTForward : public TQTBase {
    DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using TQT = TQTForward;
/**
 * \brief Compute the gradients of TQTForward wrt. input (grad_x) and scale
 * (grad_s).
 */
class TQTBackward : public TQTBase {
    DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2);

public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
/**
 * \brief Base class of the LSQ (learned step-size quantization) operators;
 * holds the shared param and forward layout helpers.
 */
class LSQBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
    DEF_OPR_PARAM(LSQ);

protected:
    //! output layout follows the input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output);
};
/**
 * \brief LSQ forward: fake-quantize \p input with learned \p scale,
 * \p zero_point and gradient scaling factor \p grad_scale.
 */
class LSQForward : public LSQBase {
    DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out output, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using LSQ = LSQForward;
/**
 * \brief Compute the gradients of LSQForward wrt. input (grad_x) and scale
 * (grad_s).
 */
class LSQBackward : public LSQBase {
    DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);

public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
/**
 * \brief Base class of the LayerNorm operators; holds the shared param and
 * forward layout helpers.
 */
class LayerNormBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
    DEF_OPR_PARAM(LayerNorm);

public:
    //! static layout deduction usable without an operator instance; fills
    //! dst, mean and rstd (reciprocal standard deviation) layouts
    MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
            const TensorLayout& data, const Param& p, TensorLayout& dst,
            TensorLayout& mean, TensorLayout& rstd);

protected:
    void deduce_layout_fwd(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
            TensorLayout& rstd);
    void check_layout_fwd(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd);
};
/**
 * \brief Layer normalization forward; besides \p dst it also outputs the
 * per-sample \p mean and \p rstd used by the backward pass.
 */
class LayerNormForward : public LayerNormBase {
    DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3);

public:
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
            _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
            TensorLayout& rstd);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd, size_t workspace_in_bytes);
};
using LayerNorm = LayerNormForward;
/**
 * \brief Compute the gradients of LayerNormForward wrt. data, weight and
 * bias, reusing the saved \p mean and \p rstd.
 */
class LayerNormBackward : public LayerNormBase {
    DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3);

public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
            _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
            _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
            TensorLayout& dbias);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, const TensorLayout& ddata,
            const TensorLayout& dweight, const TensorLayout& dbias) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, const TensorLayout& ddata,
            const TensorLayout& dweight, const TensorLayout& dbias,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of the Dropout operators; only carries the shared param.
 */
class DropoutBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DropoutBase, OperatorBase);
    DEF_OPR_PARAM(Dropout);
};
/**
 * \brief Dropout forward; outputs the dropped tensor \p oup together with the
 * \p mask consumed by DropoutBackward.
 */
class DropoutForward : public DropoutBase {
    DEF_OPR_IMPL(DropoutForward, DropoutBase, 1, 2);

public:
    void deduce_layout(const TensorLayout& inp, TensorLayout& oup, TensorLayout& mask);
    virtual void exec(
            _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& inp, const TensorLayout& oup,
            const TensorLayout& mask) = 0;
    //! backend-specific byte size of the mask for a given input layout
    virtual size_t get_mask_size_in_bytes(const TensorLayout& inp) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& inp, const TensorLayout& oup, const TensorLayout& mask,
            size_t workspace_in_bytes);
};
using Dropout = DropoutForward;
/**
 * \brief Compute the gradient of DropoutForward wrt. its input, using the
 * mask produced by the forward pass.
 */
class DropoutBackward : public DropoutBase {
    DEF_OPR_IMPL(DropoutBackward, DropoutBase, 2, 1);

public:
    void deduce_layout(
            const TensorLayout& doup, const TensorLayout& mask, TensorLayout& dinp);
    virtual void exec(
            _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& doup, const TensorLayout& mask,
            const TensorLayout& dinp) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& doup, const TensorLayout& mask,
            const TensorLayout& dinp, size_t workspace_in_bytes);
};
/**
 * \brief Base class of the Softmax operators; holds the shared param and
 * forward layout helpers.
 */
class SoftmaxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
    DEF_OPR_PARAM(Softmax);

protected:
    //! output layout follows the input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
};
class SoftmaxForward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);

public:
    /**
     * \param[in] input input tensor
     * \param[out] output output tensor
     */
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using Softmax = SoftmaxForward;
/**
 * \brief Compute the gradient of SoftmaxForward wrt. its input.
 */
class SoftmaxBackward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x, size_t workspace_in_bytes);
};
/**
 * \brief One step of a vanilla RNN cell: combines \p input with the hidden
 * state \p hx through the ih/hh weight and bias pairs to produce \p dst.
 */
class RNNCellForward : public OperatorBase {
    DEF_OPR_PARAM(RNNCell);
    DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size prior to exec
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using RNNCell = RNNCellForward;
  1979. class LSTMCellForward : public OperatorBase {
  1980. // DEF_OPR_PARAM(LSTMCell);
  1981. DEF_OPR_PARAM(Empty);
  1982. DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);
  1983. public:
  1984. virtual void exec(
  1985. _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
  1986. _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
  1987. _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
  1988. _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
  1989. _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
  1990. void deduce_layout(
  1991. const TensorLayout& input, const TensorLayout& weight_ih,
  1992. const TensorLayout& bias_ih, const TensorLayout& hx,
  1993. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  1994. const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
  1995. TensorLayout& gates);
  1996. virtual size_t get_workspace_in_bytes(
  1997. const TensorLayout& input, const TensorLayout& weight_ih,
  1998. const TensorLayout& bias_ih, const TensorLayout& hx,
  1999. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  2000. const TensorLayout& cx, const TensorLayout& h_new,
  2001. const TensorLayout& c_new, const TensorLayout& gates) = 0;
  2002. protected:
  2003. void check_exec(
  2004. const TensorLayout& input, const TensorLayout& weight_ih,
  2005. const TensorLayout& bias_ih, const TensorLayout& hx,
  2006. const TensorLayout& weight_hh, const TensorLayout& bias_hh,
  2007. const TensorLayout& cx, const TensorLayout& h_new,
  2008. const TensorLayout& c_new, const TensorLayout& gates,
  2009. size_t workspace_in_bytes);
  2010. };
  2011. using LSTMCell = LSTMCellForward;
  2012. class RNNForward : public OperatorBase {
  2013. DEF_OPR_PARAM(RNN);
  2014. DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);
  2015. public:
  2016. virtual void exec(
  2017. _megdnn_tensor_in input, _megdnn_tensor_in hx,
  2018. _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
  2019. _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
  2020. _megdnn_workspace workspace) = 0;
  2021. void deduce_layout(
  2022. const TensorLayout& input, const TensorLayout& hx,
  2023. const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
  2024. TensorLayout& reserve_space);
  2025. virtual size_t get_workspace_in_bytes(
  2026. const TensorLayout& input, const TensorLayout& hx,
  2027. const TensorLayout& flatten_weights, const TensorLayout& output,
  2028. const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
  2029. virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
  2030. protected:
  2031. void check_exec(
  2032. const TensorLayout& input, const TensorLayout& hx,
  2033. const TensorLayout& flatten_weights, const TensorLayout& output,
  2034. const TensorLayout& hy, const TensorLayout& reserve_space,
  2035. size_t workspace_in_bytes);
  2036. };
  2037. using RNN = RNNForward;
  2038. class RNNBackward : public OperatorBase {
  2039. DEF_OPR_PARAM(RNN);
  2040. DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);
  2041. public:
  2042. virtual void exec(
  2043. _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
  2044. _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
  2045. _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
  2046. _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
  2047. _megdnn_workspace workspace) = 0;
  2048. void deduce_layout(
  2049. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2050. const TensorLayout& dy, const TensorLayout& dhy,
  2051. const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
  2052. TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
  2053. virtual size_t get_workspace_in_bytes(
  2054. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2055. const TensorLayout& dy, const TensorLayout& dhy,
  2056. const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
  2057. const TensorLayout& dx, const TensorLayout& dhx,
  2058. const TensorLayout& dw) = 0;
  2059. protected:
  2060. void check_exec(
  2061. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2062. const TensorLayout& dy, const TensorLayout& dhy,
  2063. const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
  2064. const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
  2065. size_t workspace_in_bytes);
  2066. };
  2067. class LSTMForward : public OperatorBase {
  2068. DEF_OPR_PARAM(LSTM);
  2069. DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);
  2070. public:
  2071. virtual void exec(
  2072. _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
  2073. _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
  2074. _megdnn_tensor_out hy, _megdnn_tensor_out cy,
  2075. _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
  2076. void deduce_layout(
  2077. const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
  2078. const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
  2079. TensorLayout& cy, TensorLayout& reserve_space);
  2080. virtual size_t get_workspace_in_bytes(
  2081. const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
  2082. const TensorLayout& flatten_weights, const TensorLayout& output,
  2083. const TensorLayout& hy, const TensorLayout& cy,
  2084. const TensorLayout& reserve_space) = 0;
  2085. virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
  2086. protected:
  2087. void check_exec(
  2088. const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
  2089. const TensorLayout& flatten_weights, const TensorLayout& output,
  2090. const TensorLayout& hy, const TensorLayout& cy,
  2091. const TensorLayout& reserve_space, size_t workspace_in_bytes);
  2092. };
  2093. using LSTM = LSTMForward;
  2094. class LSTMBackward : public OperatorBase {
  2095. DEF_OPR_PARAM(LSTM);
  2096. DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);
  2097. public:
  2098. virtual void exec(
  2099. _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
  2100. _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
  2101. _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
  2102. _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
  2103. _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
  2104. _megdnn_workspace workspace) = 0;
  2105. void deduce_layout(
  2106. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2107. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  2108. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  2109. const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
  2110. TensorLayout& dcx, TensorLayout& dw);
  2111. virtual size_t get_workspace_in_bytes(
  2112. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2113. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  2114. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  2115. const TensorLayout& reserve_space, const TensorLayout& dx,
  2116. const TensorLayout& dhx, const TensorLayout& dcx,
  2117. const TensorLayout& dw) = 0;
  2118. protected:
  2119. void check_exec(
  2120. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2121. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  2122. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  2123. const TensorLayout& reserve_space, const TensorLayout& dx,
  2124. const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
  2125. size_t workspace_in_bytes);
  2126. };
  2127. class GroupNormBase : public OperatorBase {
  2128. DEF_OPR_IMPL_CTOR(GroupNormBase, OperatorBase);
  2129. DEF_OPR_PARAM(GroupNorm);
  2130. protected:
  2131. void deduce_layout_fwd(
  2132. const TensorLayout& data, const TensorLayout& weight,
  2133. const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
  2134. TensorLayout& rstd);
  2135. void check_layout_fwd(
  2136. const TensorLayout& data, const TensorLayout& weight,
  2137. const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
  2138. const TensorLayout& rstd);
  2139. };
  2140. class GroupNormForward : public GroupNormBase {
  2141. DEF_OPR_IMPL(GroupNormForward, GroupNormBase, 3, 3);
  2142. public:
  2143. virtual void exec(
  2144. _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
  2145. _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
  2146. _megdnn_workspace workspace) = 0;
  2147. MGE_WIN_DECLSPEC_FUC void deduce_layout(
  2148. const TensorLayout& data, const TensorLayout& weight,
  2149. const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
  2150. TensorLayout& rstd);
  2151. virtual size_t get_workspace_in_bytes(
  2152. const TensorLayout& data, const TensorLayout& weight,
  2153. const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
  2154. const TensorLayout& rstd) = 0;
  2155. protected:
  2156. void check_exec(
  2157. const TensorLayout& data, const TensorLayout& weight,
  2158. const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
  2159. const TensorLayout& rstd, size_t workspace_in_bytes);
  2160. };
  2161. using GroupNorm = GroupNormForward;
  2162. class GroupNormBackward : public GroupNormBase {
  2163. DEF_OPR_IMPL(GroupNormBackward, GroupNormBase, 5, 3);
  2164. public:
  2165. virtual void exec(
  2166. _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
  2167. _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
  2168. _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
  2169. _megdnn_workspace workspace) = 0;
  2170. void deduce_layout(
  2171. const TensorLayout& diff, const TensorLayout& data,
  2172. const TensorLayout& weight, const TensorLayout& mean,
  2173. const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
  2174. TensorLayout& dbias);
  2175. virtual size_t get_workspace_in_bytes(
  2176. const TensorLayout& diff, const TensorLayout& data,
  2177. const TensorLayout& weight, const TensorLayout& mean,
  2178. const TensorLayout& rstd, const TensorLayout& ddata,
  2179. const TensorLayout& dweight, const TensorLayout& dbias) = 0;
  2180. protected:
  2181. void check_exec(
  2182. const TensorLayout& diff, const TensorLayout& data,
  2183. const TensorLayout& weight, const TensorLayout& mean,
  2184. const TensorLayout& rstd, const TensorLayout& ddata,
  2185. const TensorLayout& dweight, const TensorLayout& dbias,
  2186. size_t workspace_in_bytes);
  2187. };
  2188. } // namespace megdnn
  2189. #include "megdnn/internal/opr_header_epilogue.h"
  2190. // vim: syntax=cpp.doxygen