/**
 * \file dnn/include/megdnn/oprs/nn.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

class SeparableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
    DEF_OPR_PARAM(SeparableConv);

public:
    using Mode = Param::Mode;

protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst);
};

class SeparableConvForward : public SeparableConvBase {
    DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
            _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using SeparableConv = SeparableConvForward;

namespace detail {

struct PreprocessedFilter {
    //! user data; its lifetime should be bound to MegDNN Convolution
    //! operator
    void* algorithm_id;
    TensorNDArray tensors;
};

}  // namespace detail

/**
 * \brief base class for convolution operation
 *
 * This operator is supposed to perform convolution on arbitrary input
 * dimensions. The input/output format is N, C, dims..., and kernel format can
 * take two forms:
 * 1. OC, IC, dims..., for conventional dense convolution
 * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
 *
 * Currently, only 2D images are supported.
 */
template <typename Parameter>
class ConvolutionBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
    using Param = Parameter;

public:
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    using Mode = typename Param::Mode;
    struct CanonizedFilterMeta {
        DType dtype;
        typename Param::Format format;

        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];

        //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
        template <typename T>
        void copy_from(const T& b) {
            dtype = b.dtype;
            format = b.format;
            should_flip = b.should_flip;
            group = b.group;
            icpg = b.icpg;
            ocpg = b.ocpg;
            spatial_ndim = b.spatial_ndim;
            memcpy(stride, b.stride, sizeof(stride));
            memcpy(padding, b.padding, sizeof(padding));
            memcpy(spatial, b.spatial, sizeof(spatial));
            memcpy(dilation, b.dilation, sizeof(dilation));
            memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
        }

        bool operator==(const CanonizedFilterMeta& b) const {
            bool flag = true;
            flag = flag && (format == b.format);
            flag = flag && (dtype == b.dtype);
            flag = flag && (should_flip == b.should_flip);
            flag = flag && (group == b.group);
            flag = flag && (icpg == b.icpg);
            flag = flag && (ocpg == b.ocpg);
            flag = flag && (spatial_ndim == b.spatial_ndim);
            if (flag) {
                for (uint32_t i = 0; i < spatial_ndim; ++i) {
                    flag = flag && (stride[i] == b.stride[i]);
                    flag = flag && (padding[i] == b.padding[i]);
                    flag = flag && (spatial[i] == b.spatial[i]);
                    flag = flag && (dilation[i] == b.dilation[i]);
                    flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
                }
            }
            return flag;
        }
    };
    using PreprocessedFilter = detail::PreprocessedFilter;

protected:
    // Check or deduce output DType
    void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
    CanonizedFilterMeta deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            TensorLayout& dst) const;
    CanonizedFilterMeta check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};

class MaskPropagate : public OperatorBase {
    DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
    DEF_OPR_PARAM(MaskPropagate);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
};

/**
 * \brief ConvolutionForward operator with a 0/1 mask matrix
 */
class MaskConvForward : public ConvolutionBase<param::Convolution> {
    DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& dst) = 0;
    void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, TensorLayout& dst);

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using MaskConvolution = MaskConvForward;

/**
 * \brief ConvolutionForward operator.
 */
class ConvolutionForward : public ConvolutionBase<param::Convolution>,
                           public detail::MultiAlgoOpr<ConvolutionForward, 3> {
    DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] preprocessed_filter nullptr if the weights have not been
     *      preprocessed; otherwise the preprocessed weights are stored in the
     *      tensors of preprocessed_filter.
     * \param[in] workspace if the weights are not preprocessed
     *      (preprocessed_filter == nullptr), the workspace size must satisfy
     *      the unpreprocessed case; otherwise it must satisfy the preprocessed
     *      case.
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            const PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief execute weight preprocessing: read the weights from filter and
     * write the preprocessed result to preprocessed_filter.
     *
     * \param[in] workspace the temporary workspace needed while
     *      exec_preprocess is running
     */
    virtual void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType src, DType filter, DType& dst);
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    /**
     * \brief query the workspace needed to execute the operator. If the
     * weights are preprocessed, preprocessed_filter is non-null; otherwise it
     * is nullptr. The workspace size may differ depending on whether the
     * weights are preprocessed.
     *
     * \return the size of workspace needed for execution
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief deduce the preprocessed filter layouts from the src, filter and
     * dst layouts; the result may contain multiple layouts when there is more
     * than one weight tensor.
     *
     * \return SmallVector<TensorLayout> the layouts of the preprocessed
     *      weights; empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    /**
     * \brief query the workspace needed to preprocess the weights; a
     * _megdnn_workspace of the returned size should be created and passed to
     * exec_preprocess.
     *
     * \return the size of workspace needed for preprocessing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using Convolution = ConvolutionForward;
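
/*
 * Illustrative usage sketch (not part of the original header): the weight
 * preprocessing workflow documented above, in call order. `handle` is assumed
 * to be a valid megdnn::Handle*; `alloc_tensor` and `alloc_workspace` are
 * hypothetical device-allocation helpers supplied by the caller.
 *
 *     auto opr = handle->create_operator<megdnn::ConvolutionForward>();
 *     megdnn::TensorLayout dst_layout;
 *     opr->deduce_layout(src.layout, filter.layout, dst_layout);
 *     megdnn::TensorND dst = alloc_tensor(dst_layout);
 *
 *     // deduce the layouts of the preprocessed weights; an empty result
 *     // means the selected algorithm does not use preprocessing
 *     megdnn::ConvolutionForward::PreprocessedFilter pf;
 *     auto pf_layouts = opr->deduce_preprocessed_filter_layout(
 *             src.layout, filter.layout, dst_layout);
 *     for (auto&& ly : pf_layouts)
 *         pf.tensors.push_back(alloc_tensor(ly));
 *     bool use_pf = !pf_layouts.empty();
 *
 *     if (use_pf) {
 *         auto pre_ws = alloc_workspace(opr->get_preprocess_workspace_in_bytes(
 *                 src.layout, filter.layout, dst_layout));
 *         opr->exec_preprocess(src.layout, filter, dst_layout, &pf, pre_ws);
 *     }
 *
 *     auto ws = alloc_workspace(opr->get_workspace_in_bytes(
 *             src.layout, filter.layout, dst_layout, use_pf ? &pf : nullptr));
 *     opr->exec(src, filter, dst, use_pf ? &pf : nullptr, ws);
 */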
/**
 * \brief ConvolutionBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data.
 */
class ConvolutionBackwardData
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    void deduce_dtype(DType filter, DType diff, DType& grad);
    void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

/**
 * \brief ConvolutionBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class ConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oc, ic, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBias operator
 */
class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
                        public detail::MultiAlgoOpr<ConvBiasForward, 5> {
    DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
     * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
     *      4 * ic)
     * \param[in] bias (1, oc, 1, 1)
     * \param[in] z same as dst
     * \param[in] preprocessed_filter nullptr if the weights have not been
     *      preprocessed; otherwise the preprocessed weights are stored in the
     *      tensors of preprocessed_filter.
     * \param[in] workspace if the weights are not preprocessed
     *      (preprocessed_filter == nullptr), the workspace size must satisfy
     *      the unpreprocessed case; otherwise it must satisfy the preprocessed
     *      case.
     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
     *
     * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
     * alphaw, oc, ic)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst,
            const PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief execute weight preprocessing: read the weights from filter and
     * bias, and write the preprocessed result to preprocessed_filter.
     *
     * \param[in] workspace the temporary workspace needed while
     *      exec_preprocess is running; its size is given by
     *      get_preprocess_workspace_in_bytes
     */
    virtual void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            _megdnn_tensor_in bias, const TensorLayout& z_layout,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
    /**
     * \brief query the workspace needed to execute the operator. If the
     * weights are preprocessed, preprocessed_filter is non-null; otherwise it
     * is nullptr. The workspace size may differ depending on whether the
     * weights are preprocessed.
     *
     * \return the size of workspace needed for execution
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief query the workspace needed to pre-process the weights; a
     * _megdnn_workspace of the returned size should be created and passed to
     * exec_preprocess.
     *
     * \return the size of workspace needed for pre-processing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    /**
     * \brief deduce the pre-processed filter layouts from the src, filter and
     * dst layouts; the result may contain multiple layouts when there is more
     * than one weight tensor.
     *
     * \return SmallVector<TensorLayout> the layouts of the pre-processed
     *      weights; empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;

    enum class BiasMode : uint32_t {
        NO_BIAS = 0,             //!< no bias
        BROADCAST_CHANNEL_BIAS,  //!< broadcast channel bias, [1, c, 1, 1]
        BIAS                     //!< [N, C, H, W]
    };

    //! param for winograd algos.
    struct WinogradParam {
        uint32_t channel_block_size;
        uint32_t output_block_size;
        uint32_t tile_size;
        bool operator==(const WinogradParam& rhs) const {
            return channel_block_size == rhs.channel_block_size &&
                   output_block_size == rhs.output_block_size &&
                   tile_size == rhs.tile_size;
        }
        std::string to_string() const;
    };
    static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};

    struct DirectParam {
        std::string to_string() const { return ""; }
    };

    struct MatmulParam {
        std::string to_string() const { return ""; }
    };

    struct DefaultParam {
        std::string to_string() const { return ""; }
    };

    //! get algo name; the format is ParamTrait<T>::category:base:p.to_string()
    //! \warning base must not contain ':'.
    template <typename T>
    static std::string algo_name(
            const std::string& base, const T& p,
            param::ConvBias::Format format = param::ConvBias::Format::NCHW);
    /*!
     * \brief parse algo_name and get the WinogradParam from the algo name.
     *
     * \param algo_name the algorithm name string
     * \return WinogradParam parsed from the algo name, which follows the
     *      pattern winograd:base:m:tile_size.
     *
     * \warning INVALID_WINOGRAD_PARAM is returned if algo_name does not match.
     */
    static WinogradParam parse_winograd_name(const std::string& algo_name);

    /**
     * @brief check whether there is an nchw_nchwxx conv kernel optimized for
     * the given arguments; nchw44 is used on arm, nchw88 on x86
     *
     * @param src_dtype conv feature map data type
     * @param filter_dtype conv filter or weight data type
     * @param dst_dtype output data type
     * @param fm filter meta param
     * @param bias_mode bias mode: no_bias, broadcast or bias
     * @param nonline_mode identity, relu, h_swish or sigmoid
     * @return true if such a kernel is found
     * @return false if no kernel is found
     */
    static bool is_nchw_nchwxx_optimized(
            const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
            const DTypeEnum dst_dtype,
            const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
            const ConvBiasForward::BiasMode bias_mode,
            const param::ConvBias::NonlineMode nonline_mode);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVBIAS_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
    CanonizedFilterMeta check_exec_allow_noncontiguous(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
};
using ConvBias = ConvBiasForward;
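
/*
 * Illustrative sketch (not part of the original header): building a Winograd
 * algorithm name and parsing it back. The resulting string follows the
 * ParamTrait<T>::category:base:p.to_string() format described above; the
 * concrete base string used here ("SOME_BASE") is a placeholder and the
 * category prefix comes from ParamTrait<WinogradParam>.
 *
 *     using ConvBias = megdnn::ConvBiasForward;
 *     ConvBias::WinogradParam wp{1, 2, 3};  // channel block, output block, tile
 *     std::string name =
 *             ConvBias::algo_name<ConvBias::WinogradParam>("SOME_BASE", wp);
 *
 *     ConvBias::WinogradParam parsed = ConvBias::parse_winograd_name(name);
 *     if (parsed == ConvBias::INVALID_WINOGRAD_PARAM) {
 *         // the name did not match the winograd:base:m:tile_size pattern
 *     }
 */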
/**
 * \brief base class for Conv - Nonline - Pooling
 */
class ConvPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);

    /**
     * \ Param::Method: two methods to fetch the input data.
     * The default method is WITH_TEXTURE_OBJ.
     * If you want to use WITH_SHARED_MEM mode,
     * please make sure that the total size of
     * [all of the filter kernels + a channel
     * of input data + a channel of output data]
     * is no larger than 38KB.
     * And the pooling mode should not be "MAX".
     */
    DEF_OPR_PARAM(ConvPooling);

protected:
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
class ConvPoolingForward : public ConvPoolingBase {
    DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            const _megdnn_in TensorND src, const _megdnn_in TensorND filter,
            const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
            _megdnn_out Workspace workspace) = 0;
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& dst) = 0;

protected:
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;

class GroupLocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};

class GroupLocalForward : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
     * \param[out] dst (N, OC, OH, OW)
     **/
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
        deduce_layout_fwd(src, filter, dst);
    }
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;

class GroupLocalBackwardData : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);

public:
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class GroupLocalBackwardFilter : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class Images2NeibsBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
    DEF_OPR_PARAM(Images2Neibs);

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
};
class Images2NeibsForward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW)
     * \param[out] dst (N, C, OH, OW, window_h, window_w)
     *
     * \see
     * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
     *
     * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww} \f$,
     * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
     * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1) \f$.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
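
/*
 * Illustrative sketch (not part of the original header): a plain reference
 * loop for the Images2Neibs mapping above, assuming NCHW float input, zero
 * padding outside the image and dilation_h == dilation_w == 1, so that
 * ih = oh * stride_h - pad_h + wh and iw = ow * stride_w - pad_w + ww.
 *
 *     // dst has shape (N, C, OH, OW, window_h, window_w)
 *     void images2neibs_ref(
 *             const float* src, float* dst, int N, int C, int IH, int IW, int OH,
 *             int OW, int window_h, int window_w, int pad_h, int pad_w,
 *             int stride_h, int stride_w) {
 *         float* out = dst;
 *         for (int n = 0; n < N; ++n)
 *             for (int c = 0; c < C; ++c)
 *                 for (int oh = 0; oh < OH; ++oh)
 *                     for (int ow = 0; ow < OW; ++ow)
 *                         for (int wh = 0; wh < window_h; ++wh)
 *                             for (int ww = 0; ww < window_w; ++ww) {
 *                                 int ih = oh * stride_h - pad_h + wh;
 *                                 int iw = ow * stride_w - pad_w + ww;
 *                                 bool inside =
 *                                         ih >= 0 && ih < IH && iw >= 0 && iw < IW;
 *                                 *out++ = inside
 *                                         ? src[((n * C + c) * IH + ih) * IW + iw]
 *                                         : 0.f;
 *                             }
 *     }
 */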
class Images2NeibsBackward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class SlidingWindowTransposeBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase);
    DEF_OPR_PARAM(SlidingWindowTranspose);

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
};

class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW, window_h, window_w)
     * \param[out] dst (N, C, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using SlidingWindowTranspose = SlidingWindowTransposeForward;

class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

/**
 * \brief base class for Pooling
 */
class PoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
    DEF_OPR_PARAM(Pooling);

public:
    using Mode = Param::Mode;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class PoolingForward : public PoolingBase,
                       public detail::MultiAlgoOpr<PoolingForward, 2> {
    DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_FORWARD;
    }

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
class PoolingBackward : public PoolingBase,
                        public detail::MultiAlgoOpr<PoolingBackward, 4> {
    DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in PoolingForward::exec
     * \param[in] dst the `dst' parameter in PoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_BACKWARD;
    }

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

/**
 * \brief base class for AdaptivePooling
 */
class AdaptivePoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
    DEF_OPR_PARAM(AdaptivePooling);

protected:
    param::Pooling deduce_pooling_param(
            const TensorLayout& src, const TensorLayout& dst);
};

class AdaptivePoolingForward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
};
using AdaptivePooling = AdaptivePoolingForward;

class AdaptivePoolingBackward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
     * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
};

/**
 * \brief base class for Local
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};

class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;

class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    void check_param();
};
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 6);

public:
    /**
     * dst[i] = gamma * (x[i] - estimatedMean[k]) /
     *          sqrt(epsilon + estimatedVariance[k]) + beta, where
     * epsilon is a very small value to avoid a "divide by zero" error.
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[out] mean (see m_param.ParamDim) Global mean.
     * \param[out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *      Optionally cached intermediate mean from the forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *      Optionally cached intermediate variance from the forward pass
     * \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
            _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
            _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean,
            _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance,
            TensorLayout& batch_mean, TensorLayout& batch_inv_variance,
            TensorLayout& reserve, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
            const TensorLayout& dst) = 0;
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& dst,
            size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
};
using BN = BNForward;
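
/*
 * Illustrative sketch (not part of the original header): the per-element
 * batch-norm formula above as a plain reference loop over an NCHW float
 * tensor, using the running statistics (inference mode). Requires <cmath>.
 *
 *     void bn_inference_ref(
 *             const float* x, const float* gamma, const float* beta,
 *             const float* mean, const float* variance, float* dst,
 *             int N, int C, int H, int W, float epsilon = 1e-5f) {
 *         for (int n = 0; n < N; ++n)
 *             for (int c = 0; c < C; ++c) {
 *                 float inv_std = 1.f / std::sqrt(epsilon + variance[c]);
 *                 for (int i = 0; i < H * W; ++i) {
 *                     int idx = (n * C + c) * H * W + i;
 *                     dst[idx] = gamma[c] * (x[idx] - mean[c]) * inv_std + beta[c];
 *                 }
 *             }
 *     }
 */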
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);

public:
    /**
     * \param[in] x input data of the forward pass.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
     *
     * Optionally cached intermediate results from the forward pass:
     * \param[in] saved_batch_mean mean of the input batch,
     *      calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch,
     *      calculated in the forward propagation.
     * \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
     */
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in dy,
            _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance,
            _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
            _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
            _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& reserve, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    void check_exec(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
            const TensorLayout& dx, size_t workspace_in_bytes,
            size_t reserve_in_bytes = 0);
};
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    void check_param();
};

class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using LRN = LRNForward;

class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};

class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn .
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in rois,
            _megdnn_tensor_in index, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;

        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    CanonizedFilterMeta deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            TensorLayout& dst) const;
    CanonizedFilterMeta check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
class Convolution3DForward : public Convolution3DBase,
                             public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;

class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);
protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);
public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_FORWARD;
    }
protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
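// A brief semantic sketch of the local-share filter layout, assuming the usual
// interpretation of locally shared convolution: the output spatial domain is
// partitioned into a spatial_groups_h x spatial_groups_w grid, output positions
// in the same grid cell share one (IC / G, FH, FW, OC / G) filter, and different
// cells use different filters; G is the ordinary channel-group count, as in
// grouped convolution. For example:
//
//     G = 1, spatial_groups_h = spatial_groups_w = 2, IC = 4, OC = 8,
//     FH = FW = 3  ->  filter shape (1, 2, 2, 4, 3, 3, 8)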
class LocalShareBackwardData : public LocalShareBase,
                               public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);
public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
    }
protected:
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);
public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
    }
protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);
protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);
public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(i, 0) denotes the index of the source image of the i-th
     * ROI. It is stored as a float, but it is semantically an integer.
     *
     * index is a temporary tensor that facilitates the backward operator.
     * It stores the argmax indices in MAX mode and is unused in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;
protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
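// A minimal sketch of the expected `rois` contents, assuming the common
// (image_index, x0, y0, x1, y1) row convention; only the first column being
// the source-image index is stated in the comment above, the box-coordinate
// order is an assumption:
//
//     float rois[2][5] = {
//             {0.f, 4.5f, 3.0f, 60.5f, 52.0f},   // ROI 0 taken from image 0
//             {1.f, 10.0f, 8.0f, 30.0f, 40.0f},  // ROI 1 taken from image 1
//     };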
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);
public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in rois, _megdnn_tensor_in index,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;
protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);
public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };
protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    void deduce_layout_fwd(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset,
            const TensorLayout& dst);
};
class DeformableConvForward : public DeformableConvBase,
                              public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);
public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
    }
protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
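// A brief indexing sketch for the extra inputs, assuming the usual modulated
// deformable-convolution formulation (the per-axis ordering of the two offset
// components is an assumption): for deformable group g, filter tap (fh, fw)
// and output position (oh, ow), offset(g, 0:2, fh, fw, oh, ow) holds the
// learned displacement added to that tap's sampling location, and
// mask(g, fh, fw, oh, ow) is the scalar modulation applied to the bilinearly
// sampled value; dg is expected to divide ic so that each deformable group
// covers ic / dg input channels.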
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);
public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
            _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            TensorLayout& filter_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
    }
protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad, size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);
public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
            _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad,
            _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, TensorLayout& im_grad,
            TensorLayout& offset_grad, TensorLayout& mask_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
    }
protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);
protected:
    void deduce_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, TensorLayout& out_data, TensorLayout& out_count);
    void check_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, 2);
public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data (n, ic, ih, iw)
     * \param[out] out_count (n, ic, ih, iw)
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count) = 0;
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_out out_data, _megdnn_tensor_out out_count,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, TensorLayout& out_data, TensorLayout& out_count);
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, 2);
public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx)
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff (n, ic, ih, iw)
     * \param[out] trans_diff (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count,
            _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff) = 0;
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
class BatchConvBiasForward : public ConvolutionBase<param::BatchConvBias>,
                             public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);
public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCH_CONV_FORWARD;
    }
protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);
protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output);
};
class FakeQuantForward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
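// A minimal numeric sketch of fake quantization, assuming the conventional
// simulated-quantization formula with the quantization bounds qmin/qmax taken
// from the FakeQuant param:
//
//     q      = clamp(round(input / scale) + zero_point, qmin, qmax)
//     output = (q - zero_point) * scale
//
// e.g. input = 0.27, scale = 0.1, zero_point = 0  ->  q = 3, output = 0.3.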
class FakeQuantBackward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);
public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad) = 0;
protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class TQTBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase);
    DEF_OPR_PARAM(TQT);
protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output);
};
class TQTForward : public TQTBase {
    DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using TQT = TQTForward;
class TQTBackward : public TQTBase {
    DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2);
public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;
protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
class LSQBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
    DEF_OPR_PARAM(LSQ);
protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output);
};
class LSQForward : public LSQBase {
    DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out output, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using LSQ = LSQForward;
class LSQBackward : public LSQBase {
    DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);
public:
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;
protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
class RNNCellForward : public OperatorBase {
    DEF_OPR_PARAM(RNNCell);
    DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using RNNCell = RNNCellForward;
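// A minimal formula sketch for a single RNN cell step, assuming the
// conventional Elman-style cell whose nonlinearity (tanh or relu) is selected
// by the RNNCell param:
//
//     dst = act(input * weight_ih^T + bias_ih + hx * weight_hh^T + bias_hh)
//
// where input is (batch, input_size) and hx, dst are (batch, hidden_size).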
class LSTMCellForward : public OperatorBase {
    // DEF_OPR_PARAM(LSTMCell);
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
            _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
            TensorLayout& gates);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates,
            size_t workspace_in_bytes);
};
using LSTMCell = LSTMCellForward;
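// A minimal formula sketch for a single LSTM cell step, assuming the
// conventional LSTM gate equations; the concatenated pre-activation gate
// values are what the `gates` output is presumed to hold:
//
//     i = sigmoid(W_ii x + b_ii + W_hi hx + b_hi)
//     f = sigmoid(W_if x + b_if + W_hf hx + b_hf)
//     g = tanh   (W_ig x + b_ig + W_hg hx + b_hg)
//     o = sigmoid(W_io x + b_io + W_ho hx + b_ho)
//     c_new = f * cx + i * g
//     h_new = o * tanh(c_new)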
class RNNForward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
            TensorLayout& reserve_space);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space,
            size_t workspace_in_bytes);
};
using RNN = RNNForward;
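/*
 * A minimal usage sketch for RNNForward, assuming a megdnn Handle* named
 * `handle` that provides create_operator<Opr>(); sizes such as seq_len and the
 * pre-built `flatten_weights` layout (all layer weights packed into one 1-D
 * tensor) are hypothetical. The key point is that reserve_space is sized via
 * get_reserve_size_in_bytes() and must be kept around if RNNBackward is to
 * consume it later:
 *
 *     auto opr = handle->create_operator<megdnn::RNNForward>();
 *     megdnn::TensorLayout input({seq_len, batch, input_size},
 *                                megdnn::dtype::Float32());
 *     megdnn::TensorLayout hx({num_layers, batch, hidden_size},
 *                             megdnn::dtype::Float32());
 *     megdnn::TensorLayout output, hy, reserve_space;
 *     opr->deduce_layout(input, hx, flatten_weights, output, hy, reserve_space);
 *     size_t workspace_size = opr->get_workspace_in_bytes(
 *             input, hx, flatten_weights, output, hy, reserve_space);
 *     size_t reserve_size = opr->get_reserve_size_in_bytes(input);
 *     // bind TensorNDs to these layouts and call opr->exec(...).
 */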
class RNNBackward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);
public:
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
            _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
            _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx,
            const TensorLayout& dw) = 0;
protected:
    void check_exec(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
            size_t workspace_in_bytes);
};
class LSTMForward : public OperatorBase {
    DEF_OPR_PARAM(LSTM);
    DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);
public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out cy,
            _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
            TensorLayout& cy, TensorLayout& reserve_space);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space) = 0;
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;
protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space, size_t workspace_in_bytes);
};
using LSTM = LSTMForward;
class LSTMBackward : public OperatorBase {
    DEF_OPR_PARAM(LSTM);
    DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);
public:
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
            _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
            _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
            _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
            _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
            TensorLayout& dcx, TensorLayout& dw);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, const TensorLayout& dx,
            const TensorLayout& dhx, const TensorLayout& dcx,
            const TensorLayout& dw) = 0;
protected:
    void check_exec(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, const TensorLayout& dx,
            const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
            size_t workspace_in_bytes);
};
}  // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
// vim: syntax=cpp.doxygen