You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 89 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
/**
 * \brief base class for separable convolution operators.
 *
 * Takes two filter tensors (\p filter_x and \p filter_y) in addition to the
 * input; presumably these are the per-axis 1-D kernels of a separable
 * convolution — confirm against the implementation of deduce_layout_fwd.
 */
class SeparableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
    DEF_OPR_PARAM(SeparableConv);

public:
    using Mode = Param::Mode;

protected:
    //! compute \p dst layout from the input and the two filter layouts
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    //! validate that the given layouts form a consistent call
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst);
};
/**
 * \brief forward separable convolution: 3 inputs (src, filter_x, filter_y),
 * 1 output (dst).
 */
class SeparableConvForward : public SeparableConvBase {
    DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter_x first filter tensor
     * \param[in] filter_y second filter tensor
     * \param[out] dst output tensor
     * \param[in] workspace scratch memory; its size must be at least the
     *      value returned by get_workspace_in_bytes for the same layouts
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
            _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce \p dst layout from the inputs (see SeparableConvBase)
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using SeparableConv = SeparableConvForward;
namespace detail {
//! Opaque container for algorithm-specific preprocessed weights.
struct PreprocessedFilter {
    //! user data; its lifetime should be bound to MegDNN Convolution
    //! operator
    void* algorithm_id;
    //! the preprocessed weight tensors themselves
    TensorNDArray tensors;
};
}  // namespace detail
/**
 * \brief base class for convolution operation
 *
 * This operator is supposed to perform convolution on arbitrary input
 * dimensions. The input/output format is N, C, dims..., and kernel format can
 * take two forms:
 * 1. OC, IC, dims..., for conventional dense convolution
 * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
 *
 * Currently, only 2D images are supported.
 */
template <typename Parameter>
class ConvolutionBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
    using Param = Parameter;

public:
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    using Mode = typename Param::Mode;
    //! Filter description normalized into a format-independent form, so
    //! algorithms need not re-parse the raw filter layout.
    struct CanonizedFilterMeta {
        DType dtype;
        typename Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];

        //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
        template <typename T>
        void copy_from(const T& b) {
            dtype = b.dtype;
            format = b.format;
            should_flip = b.should_flip;
            group = b.group;
            icpg = b.icpg;
            ocpg = b.ocpg;
            spatial_ndim = b.spatial_ndim;
            memcpy(stride, b.stride, sizeof(stride));
            memcpy(padding, b.padding, sizeof(padding));
            memcpy(spatial, b.spatial, sizeof(spatial));
            memcpy(dilation, b.dilation, sizeof(dilation));
            memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
        }

        //! field-wise equality; spatial arrays are compared only up to
        //! spatial_ndim entries, and only if all scalar fields already match
        bool operator==(const CanonizedFilterMeta& b) const {
            bool flag = true;
            flag = flag && (format == b.format);
            flag = flag && (dtype == b.dtype);
            flag = flag && (should_flip == b.should_flip);
            flag = flag && (group == b.group);
            flag = flag && (icpg == b.icpg);
            flag = flag && (ocpg == b.ocpg);
            flag = flag && (spatial_ndim == b.spatial_ndim);
            if (flag) {
                for (uint32_t i = 0; i < spatial_ndim; ++i) {
                    flag = flag && (stride[i] == b.stride[i]);
                    flag = flag && (padding[i] == b.padding[i]);
                    flag = flag && (spatial[i] == b.spatial[i]);
                    flag = flag && (dilation[i] == b.dilation[i]);
                    flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
                }
            }
            return flag;
        }
    };
    using PreprocessedFilter = detail::PreprocessedFilter;

protected:
    // Check or deduce output DType
    void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
    //! deduce \p dst layout and return the canonized filter description
    CanonizedFilterMeta deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            TensorLayout& dst) const;
    //! validate layouts; returns the canonized filter description
    CanonizedFilterMeta check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) const;
    //! parse a raw filter layout into CanonizedFilterMeta
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
/**
 * \brief propagate a mask tensor: 1 input, 1 output.
 *
 * Used together with MaskConvForward — presumably computes the output mask
 * produced by a convolution-like window over the input mask; confirm against
 * the implementation.
 */
class MaskPropagate : public OperatorBase {
    DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
    DEF_OPR_PARAM(MaskPropagate);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    //! deduce \p dst layout from \p src
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
};
  143. /**
  144. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  145. */
  146. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  147. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  148. public:
  149. virtual void exec(
  150. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask,
  151. _megdnn_tensor_out dst, _megdnn_workspace worksapce) = 0;
  152. virtual size_t get_workspace_in_bytes(
  153. const TensorLayout& src, const TensorLayout& filter,
  154. const TensorLayout& mask, const TensorLayout& dst) = 0;
  155. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  156. void deduce_layout(
  157. const TensorLayout& src, const TensorLayout& filter,
  158. const TensorLayout& mask, TensorLayout& dst);
  159. protected:
  160. CanonizedFilterMeta check_exec(
  161. const TensorLayout& src, const TensorLayout& filter,
  162. const TensorLayout& mask, const TensorLayout& dst,
  163. size_t workspace_in_bytes);
  164. };
  165. using MaskConvolution = MaskConvForward;
/**
 * \brief ConvolutionForward operator.
 */
class ConvolutionForward : public ConvolutionBase<param::Convolution>,
                           public detail::MultiAlgoOpr<ConvolutionForward, 3> {
    DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] preprocessed_filter if the weights are not preprocessed it
     * will be nullptr, else the preprocessed weights are stored in the tensors
     * of preprocessed_filter.
     * \param[in] workspace if the weights are not preprocessed
     * (preprocessed_filter == nullptr), the size of the workspace satisfies
     * the situation that weights are not processed; otherwise the size of the
     * workspace satisfies the situation that weights are preprocessed
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            const PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief execute weight preprocessing: read weights from filter and write
     * to preprocessed_filter after preprocessing.
     *
     * \param[in] workspace the needed tmp workspace when exec_preprocess runs
     */
    virtual void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    //! check or deduce the output dtype from the input dtypes
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType src, DType filter, DType& dst);
    //! deduce \p dst layout from the inputs
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    /**
     * \brief query the workspace needed when executing the opr; if the weights
     * are preprocessed, preprocessed_filter will not be nullptr, else it will
     * be nullptr; the workspace size may differ depending on whether the
     * weights are preprocessed
     *
     * \return the size of workspace needed when executing
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief deduce the preprocessed filter layouts according to the src,
     * filter and dst layouts; the result may contain multiple layouts when
     * the weight is not a single tensor
     *
     * \return SmallVector<TensorLayout> the layouts of preprocessed weights;
     * empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    /**
     * \brief query the workspace needed when preprocessing the weights;
     * according to the returned size, a _megdnn_workspace will be created and
     * passed to exec_preprocess
     *
     * \return the size of workspace needed when preprocessing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_FORWARD;
    }

protected:
    //! validate layouts and workspace size; returns canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using Convolution = ConvolutionForward;
/**
 * \brief ConvolutionBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data.
 */
class ConvolutionBackwardData
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! check or deduce the gradient dtype from filter and diff dtypes
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType filter, DType diff, DType& grad);
    //! deduce \p grad layout from filter and diff layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size; returns canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class ConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oc, ic, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size; returns canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBias operator
 */
class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
                        public detail::MultiAlgoOpr<ConvBiasForward, 5> {
    DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
     * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
     * 4 * ic)
     * \param[in] bias (1, oc, 1, 1)
     * \param[in] z same as dst
     * \param[in] preprocessed_filter if the weights are not preprocessed it
     * will be nullptr, else the preprocessed weights are stored in the tensors
     * of preprocessed_filter.
     * \param[in] workspace if the weights are not preprocessed
     * (preprocessed_filter == nullptr), the size of the workspace satisfies
     * the situation that weights are not processed; otherwise the size of the
     * workspace satisfies the situation that weights are preprocessed
     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
     *
     * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
     * alphaw, oc, ic)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst,
            const PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    //! convenience overload without preprocessed weights
    MGE_WIN_DECLSPEC_FUC void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
        exec(src, filter, bias, z, dst, nullptr, workspace);
    }
    /**
     * \brief execute weight preprocessing: read weights from filter and bias,
     * write to preprocessed_filter after preprocessing.
     *
     * \param[in] workspace the needed tmp workspace when exec_preprocess is
     * running; the size is got by get_preprocess_workspace_in_bytes
     */
    virtual void exec_preprocess(
            const TensorLayout& src_layout, _megdnn_tensor_in filter,
            _megdnn_tensor_in bias, const TensorLayout& z_layout,
            const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
            _megdnn_workspace workspace) = 0;
    //! check or deduce the output dtype from the input dtypes
    MGE_WIN_DECLSPEC_FUC void deduce_dtype(
            DType src, DType filter, DType bias, DType z, DType& dst);
    //! deduce \p dst layout from the inputs
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
    /**
     * \brief query the workspace needed when executing the opr; if the weights
     * are preprocessed, preprocessed_filter will not be nullptr, else it will
     * be nullptr; the workspace size may differ depending on whether the
     * weights are preprocessed
     *
     * \return the size of workspace needed when executing
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    /**
     * \brief query the workspace needed when pre-processing the weights;
     * according to the returned size, a _megdnn_workspace will be created and
     * passed to exec_preprocess
     *
     * \return the size of workspace needed when pre-processing
     */
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    /**
     * \brief deduce the pre-processed filter layouts according to the src,
     * filter and dst layouts, which may contain multiple layouts when the
     * weight is not a single tensor
     *
     * \return SmallVector<TensorLayout> the layouts of preprocessed weights;
     * empty if preprocessing is not needed.
     */
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    enum class BiasMode : uint32_t {
        NO_BIAS = 0,             //!< no bias
        BROADCAST_CHANNEL_BIAS,  //!< broadcast channel bias, [1, c, 1, 1]
        BIAS                     //!< [N, C, H, W]
    };
    //! param for winograd algos.
    struct WinogradParam {
        uint32_t channel_block_size;
        uint32_t output_block_size;
        uint32_t tile_size;
        bool operator==(const WinogradParam& rhs) const {
            return channel_block_size == rhs.channel_block_size &&
                   output_block_size == rhs.output_block_size &&
                   tile_size == rhs.tile_size;
        }
        std::string to_string() const;
    };
    //! sentinel returned by parse_winograd_name on failure
    static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
    struct DirectParam {
        std::string to_string() const { return ""; }
    };
    struct MatmulParam {
        std::string to_string() const { return ""; }
    };
    struct DefaultParam {
        std::string to_string() const { return ""; }
    };
    //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
    //! \warning: base must not contain :.
    template <typename T>
    static std::string algo_name(
            const std::string& base, const T& p,
            param::ConvBias::Format format = param::ConvBias::Format::NCHW);
    /*!
     * \brief parse algo_name and get WinogradParam from algo name.
     *
     * \param algo name string
     * \return WinogradParam parsed from algo name, use pattern
     * winograd:base:m:tile_size.
     *
     * \warning: INVALID_WINOGRAD_PARAM is returned if the algo_name does not
     * match.
     */
    static WinogradParam parse_winograd_name(const std::string& algo_name);
    /**
     * @brief find whether there is an nchw_nchwxx conv kernel optimized for
     * the argument; nchw44 is used for arm, nchw88 for x86
     *
     * @param src_dtype conv feature map data type
     * @param filter_dtype conv filter or weight data type
     * @param dst_dtype output data type
     * @param fm filter meta param
     * @param bias_mode bias mode, no_bias or broadcast or bias
     * @param nonline_mode identity or relu or h_swish or sigmoid
     * @return true, found a kernel
     * @return false, cannot find any kernel
     */
    static bool is_nchw_nchwxx_optimized(
            const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
            const DTypeEnum dst_dtype,
            const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
            const ConvBiasForward::BiasMode bias_mode,
            const param::ConvBias::NonlineMode nonline_mode);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVBIAS_FORWARD;
    }

protected:
    //! validate layouts and workspace size; returns canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
    //! like check_exec but tolerates non-contiguous layouts
    CanonizedFilterMeta check_exec_allow_noncontiguous(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter);
};
using ConvBias = ConvBiasForward;
/**
 * \brief base class for Conv - Nonline - Pooling
 */
class ConvPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
    /**
     * \ Param::Method: Two methods to fetch the input data.
     * Default method is WITH_TEXTURE_OBJ.
     * If you want to use WITH_SHARED_MEM mode,
     * please make sure that the size of
     * [ all of the filter kernels + a channel
     * of input data + a channel of output data]
     * should be no larger than 38KB.
     * And the pooling mode should not be "MAX".
     */
    DEF_OPR_PARAM(ConvPooling);

protected:
    //! compute \p dst layout from the input, filter and bias layouts
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    //! validate the given layouts against the workspace limit
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
/**
 * \brief forward fused Conv - Nonline - Pooling operator.
 */
class ConvPoolingForward : public ConvPoolingBase {
    DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter convolution filter
     * \param[in] bias bias tensor
     * \param[out] dst output tensor
     * \param[out] workspace scratch memory, sized via get_workspace_in_bytes
     */
    virtual void exec(
            const _megdnn_in TensorND src, const _megdnn_in TensorND filter,
            const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
            _megdnn_out Workspace workspace) = 0;
    //! compute \p dst layout from the input, filter and bias layouts
    virtual void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst) = 0;
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& dst) = 0;

protected:
    //! validate the given layouts against the workspace limit
    virtual void check_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, TensorLayout& dst,
            size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;
/**
 * \brief base class for grouped local (non-shared) convolution operators.
 */
class GroupLocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! compute \p dst layout from input and filter layouts
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! validate that the given layouts form a consistent call
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
/**
 * \brief forward grouped local convolution: 2 inputs, 1 output.
 */
class GroupLocalForward : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
     * \param[out] dst (N, OC, OH, OW)
     **/
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce \p dst layout; forwards to GroupLocalBase::deduce_layout_fwd
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
        deduce_layout_fwd(src, filter, dst);
    }
    //! query required workspace size in bytes for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;
class GroupLocalBackwardData : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] filter the `filter' parameter in GroupLocalForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
  567. class GroupLocalBackwardFilter : public GroupLocalBase {
  568. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  569. public:
  570. virtual void exec(
  571. _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  572. _megdnn_workspace workspace) = 0;
  573. virtual size_t get_workspace_in_bytes(
  574. const TensorLayout& src, const TensorLayout& diff,
  575. const TensorLayout& grad) = 0;
  576. protected:
  577. void check_exec(
  578. const TensorLayout& filter, const TensorLayout& diff,
  579. const TensorLayout& grad, size_t workspace_in_bytes);
  580. };
  581. class Images2NeibsBase : public OperatorBase {
  582. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  583. DEF_OPR_PARAM(Images2Neibs);
  584. protected:
  585. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  586. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  587. };
class Images2NeibsForward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW)
     * \param[out] dst (N, C, OH, OW, window_h, window_w)
     *
     * \see
     * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
     *
     * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
     * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
     * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    //! deduce dst layout from src layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
class Images2NeibsBackward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
  631. class SlidingWindowTransposeBase : public OperatorBase {
  632. DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase);
  633. DEF_OPR_PARAM(SlidingWindowTranspose);
  634. protected:
  635. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  636. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  637. };
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW, window_h, window_w)
     * \param[out] dst (N, C, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    //! deduce dst layout from src layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using SlidingWindowTranspose = SlidingWindowTransposeForward;
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
    DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for Pooling
 */
class PoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
    DEF_OPR_PARAM(Pooling);

public:
    using Mode = Param::Mode;

protected:
    //! deduce dst layout from src (forward direction)
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    //! validate src/dst layouts (forward direction)
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);

public:
    //! static layout deduction that needs no operator instance
    static void deduce_layout_impl(
            const TensorLayout& src, const Param& param, TensorLayout& dst);
};
class PoolingForward : public PoolingBase,
                       public detail::MultiAlgoOpr<PoolingForward, 2> {
    DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src layout
    MGE_WIN_DECLSPEC_FUC void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_FORWARD;
    }

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
class PoolingBackward : public PoolingBase,
                        public detail::MultiAlgoOpr<PoolingBackward, 4> {
    DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in PoolingForward::exec
     * \param[in] dst the `dst' parameter in PoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::POOLING_BACKWARD;
    }

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for AdaptivePooling
 */
class AdaptivePoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
    DEF_OPR_PARAM(AdaptivePooling);

protected:
    //! derive an equivalent Pooling param from the src/dst layout pair
    param::Pooling deduce_pooling_param(
            const TensorLayout& src, const TensorLayout& dst);
};
class AdaptivePoolingForward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
};
using AdaptivePooling = AdaptivePoolingForward;
class AdaptivePoolingBackward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
     * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
};
/**
 * \brief base class for Local
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce dst layout from src and filter (forward direction)
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! validate src/filter/dst layouts (forward direction)
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for batch normalization operators
 */
class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    //! validate the BN param (e.g. its ParamDim/mode settings)
    void check_param();
};
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 6);

public:
    /**
     * \brief dst[i] = gamma
     * *(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + beta \where
     * epsilon is a very small value to avoid a "divide by zero" error.
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[out] mean (see m_param.ParamDim) Global mean.
     * \param[out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *   Optionally cached intermediate mean from forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *   Optionally cached intermediate variance from forward pass
     * \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
            _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
            _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean,
            _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layouts of all outputs from src/bn_scale/bn_bias
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance,
            TensorLayout& batch_mean, TensorLayout& batch_inv_variance,
            TensorLayout& reserve, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
            const TensorLayout& dst) = 0;
    //! size of the `reserve' tensor required for the given src layout
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    //! validate layouts, workspace and reserve sizes before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance, const TensorLayout& dst,
            size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
};
using BN = BNForward;
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);

public:
    /**
     * \param[in] x input data of the forward propagation.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
     * Optionally cached intermediate results from forward pass:
     * \param[in] saved_batch_mean mean of the input batch.
     *   Calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch.
     *   Calculated in the forward propagation.
     * \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
     */
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in dy,
            _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance,
            _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
            _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias,
            _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& reserve, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;
    //! size of the `reserve' tensor required for the given src layout
    virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;

protected:
    //! validate layouts, workspace and reserve sizes before exec
    void check_exec(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale,
            const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
            const TensorLayout& dx, size_t workspace_in_bytes,
            size_t reserve_in_bytes = 0);
};
/**
 * \brief base class for local response normalization operators
 */
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! validate the LRN param
    void check_param();
};
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using LRN = LRNForward;
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for ROIPooling operators
 */
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate src/rois/dst/index layouts (forward direction)
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in rois,
            _megdnn_tensor_in index, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for 3D convolution operators
 */
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;
    //! canonical, layout-independent description of a 3D conv filter
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! deduce dst layout from src and filter; returns the canonized filter
    CanonizedFilterMeta deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            TensorLayout& dst) const;
    //! validate src/filter/dst layouts; returns the canonized filter
    CanonizedFilterMeta check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) const;
    static CanonizedFilterMeta make_canonized_filter_meta_impl(
            size_t src_ndim, const TensorLayout& filter, const Param& param);
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
class Convolution3DForward : public Convolution3DBase,
                             public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! deduce dst layout from src and filter layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_FORWARD;
    }

protected:
    //! validate layouts and workspace size; returns the canonized filter
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    //! static layout deduction that needs no operator instance
    static void deduce_layout_impl(
            const TensorLayout& filter, const TensorLayout& diff, const Param& param,
            TensorLayout& grad);
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! deduce grad layout from filter and diff layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size; returns the canonized filter
    CanonizedFilterMeta check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size; returns the canonized filter
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for LocalShare operators
 */
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! deduce dst layout from src and filter (forward direction)
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    //! validate src/filter/dst layouts (forward direction)
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_FORWARD;
    }

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
class LocalShareBackwardData : public LocalShareBase,
                               public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(
            _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! deduce grad layout from filter and diff layouts
    void deduce_layout(
            const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad);
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
    }

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
    //! operator type tag used by the algorithm-selection framework
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
    }

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for ROIAlign operators
 */
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! deduce dst and index layouts from src and rois (forward direction)
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    //! validate src/rois/dst/index layouts (forward direction)
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst,
            _megdnn_tensor_out index, _megdnn_workspace workspace) = 0;
    //! deduce dst and index layouts from src and rois layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
            TensorLayout& index);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size before exec
    void check_exec(
            const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
            const TensorLayout& index, size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in rois, _megdnn_tensor_in index,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& diff, const TensorLayout& rois,
            const TensorLayout& index, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
/*!
 * \brief Base class for deformable convolution operators; extends the plain
 * convolution filter meta with the number of deformable groups.
 */
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    //! convolution filter meta plus the deformable-group count
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };

protected:
    //! build a canonized filter meta from the filter and offset layouts
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    //! deduce dst layout from im/filter/mask/offset
    void deduce_layout_fwd(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset, TensorLayout& dst);
    //! validate the forward layouts
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& mask, const TensorLayout& offset,
            const TensorLayout& dst);
};
class DeformableConvForward : public DeformableConvBase,
                              public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! infer dst layout from the four input layouts
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask, TensorLayout& dst);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
    }

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     *
     * NOTE(review): the original comment documented im as (oc, ic, fh, fw)
     * and filter_grad as (oc, ic, ih, iw); corrected here to match
     * DeformableConvForward — confirm against the implementation.
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
            _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad) = 0;
    //! infer filter_grad layout from the input layouts
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            TensorLayout& filter_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
    }

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& offset,
            const TensorLayout& mask, const TensorLayout& out_grad,
            const TensorLayout& filter_grad, size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     *
     * NOTE(review): the original comment documented im as (oc, ic, fh, fw);
     * corrected to (n, ic, ih, iw) to match DeformableConvForward and
     * im_grad — confirm against the implementation.
     */
    virtual void exec(
            _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset,
            _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
            _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad,
            _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    //! infer the three gradient layouts from the input layouts
    void deduce_layout(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, TensorLayout& im_grad,
            TensorLayout& offset_grad, TensorLayout& mask_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
    }

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
/*!
 * \brief Base class for deformable position-sensitive ROI pooling; holds the
 * param and the shared layout deduction / validation helpers.
 */
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! deduce out_data and out_count layouts from data/trans/rois
    void deduce_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, TensorLayout& out_data, TensorLayout& out_count);
    //! validate forward layouts and workspace size
    void check_layout_fwd(
            const TensorLayout& data, const TensorLayout& trans,
            const TensorLayout& rois, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, 2);

public:
    /**
     * \param[in] data input feature map
     * \param[in] rois regions of interest
     * \param[in] trans deformable offsets
     * \param[out] out_data pooled output
     * \param[out] out_count per-bin sample counts (used by backward)
     *
     * NOTE(review): the original comment gave placeholder/implausible shapes
     * (e.g. data as (oc, ic, ih, iw), rois as (xx, xx, xx, xx)); the exact
     * layouts are not derivable from this header — confirm against the
     * implementation.
     */
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count) = 0;
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_out out_data, _megdnn_tensor_out out_count,
            _megdnn_workspace workspace) = 0;
    //! infer out_data/out_count layouts from the input layouts
    void deduce_layout(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, TensorLayout& out_data, TensorLayout& out_count);
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_data,
            const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, 2);

public:
    /**
     * \param[in] data the `data' parameter in DeformablePSROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in DeformablePSROIPoolingForward::exec
     * \param[in] trans the `trans' parameter in DeformablePSROIPoolingForward::exec
     * \param[in] out_diff gradient wrt. out_data
     * \param[in] out_count the `out_count' output of the forward operator
     * \param[out] data_diff gradient wrt. data
     * \param[out] trans_diff gradient wrt. trans
     *
     * NOTE(review): the original comment gave placeholder shapes
     * ((xx, xx, xx, xx)); exact layouts are not derivable from this header —
     * confirm against the implementation.
     */
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans,
            _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count,
            _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff) = 0;
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& data, const TensorLayout& rois,
            const TensorLayout& trans, const TensorLayout& out_diff,
            const TensorLayout& out_count, const TensorLayout& data_diff,
            const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
/*!
 * \brief Convolution-bias-activation where each sample in the batch has its
 * own filter (param type BatchConvBias); fuses optional bias and residual z.
 */
class BatchConvBiasForward : public ConvolutionBase<param::BatchConvBias>,
                             public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
            _megdnn_tensor_in z, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    //! infer dst dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! infer dst layout from the input layouts
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCH_CONV_FORWARD;
    }

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
/*!
 * \brief Base class for fake-quantization operators (quantize-dequantize with
 * learnable scale and zero point).
 */
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);

protected:
    //! output layout equals input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    //! validate the forward layouts
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output);
};
class FakeQuantForward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);

public:
    //! apply fake quantization to input with the given scale and zero_point
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    //! infer output layout from input
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, TensorLayout& output);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
class FakeQuantBackward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);

public:
    //! compute grad wrt. input from diff (gradient wrt. output)
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
/*!
 * \brief Base class for TQT (trained quantization thresholds) operators.
 */
class TQTBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase);
    DEF_OPR_PARAM(TQT);

protected:
    //! output layout equals input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    //! validate the forward layouts
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output);
};
class TQTForward : public TQTBase {
    DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1);

public:
    //! quantize input with a learnable (log-domain) scale
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    //! infer output layout from input
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale, TensorLayout& output);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using TQT = TQTForward;
class TQTBackward : public TQTBase {
    DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2);

public:
    //! compute gradients wrt. input (grad_x) and scale (grad_s)
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
/*!
 * \brief Base class for LSQ (learned step size quantization) operators.
 */
class LSQBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
    DEF_OPR_PARAM(LSQ);

protected:
    //! output layout equals input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    //! validate the forward layouts
    void check_layout_fwd(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output);
};
class LSQForward : public LSQBase {
    DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);

public:
    //! quantize input with learned step size; grad_scale is the gradient
    //! scaling factor used by the LSQ formulation
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out output, _megdnn_workspace workspace) = 0;
    //! infer output layout from input
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            TensorLayout& output);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& scale,
            const TensorLayout& zero_point, const TensorLayout& grad_scale,
            const TensorLayout& output, size_t workspace_in_bytes);
};
using LSQ = LSQForward;
class LSQBackward : public LSQBase {
    DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);

public:
    //! compute gradients wrt. input (grad_x) and scale (grad_s)
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale,
            _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
            _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& diff, const TensorLayout& input,
            const TensorLayout& scale, const TensorLayout& zero_point,
            const TensorLayout& grad_scale, const TensorLayout& grad_x,
            const TensorLayout& grad_s, size_t workspace_in_bytes);
};
/*!
 * \brief Base class for LayerNorm operators; exposes a static layout
 * deduction helper usable without an operator instance.
 */
class LayerNormBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
    DEF_OPR_PARAM(LayerNorm);

public:
    //! static variant of forward layout deduction driven only by the param
    MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
            const TensorLayout& data, const Param& p, TensorLayout& dst,
            TensorLayout& mean, TensorLayout& rstd);

protected:
    //! deduce dst/mean/rstd layouts from data/weight/bias
    void deduce_layout_fwd(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
            TensorLayout& rstd);
    //! validate the forward layouts
    void check_layout_fwd(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd);
};
class LayerNormForward : public LayerNormBase {
    DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3);

public:
    //! normalize data with affine weight/bias; also outputs mean and
    //! reciprocal standard deviation (rstd) for the backward pass
    virtual void exec(
            _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
            _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
            _megdnn_workspace workspace) = 0;
    //! infer dst/mean/rstd layouts from the input layouts
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
            TensorLayout& rstd);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& data, const TensorLayout& weight,
            const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
            const TensorLayout& rstd, size_t workspace_in_bytes);
};
using LayerNorm = LayerNormForward;
class LayerNormBackward : public LayerNormBase {
    DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3);

public:
    //! compute gradients wrt. data (ddata), weight (dweight) and bias
    //! (dbias), reusing mean/rstd saved by the forward pass
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
            _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
            _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
            _megdnn_workspace workspace) = 0;
    //! infer the three gradient layouts from the input layouts
    void deduce_layout(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
            TensorLayout& dbias);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, const TensorLayout& ddata,
            const TensorLayout& dweight, const TensorLayout& dbias) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& diff, const TensorLayout& data,
            const TensorLayout& weight, const TensorLayout& mean,
            const TensorLayout& rstd, const TensorLayout& ddata,
            const TensorLayout& dweight, const TensorLayout& dbias,
            size_t workspace_in_bytes);
};
//! Base class for Dropout operators; only holds the Dropout param.
class DropoutBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DropoutBase, OperatorBase);
    DEF_OPR_PARAM(Dropout);
};
class DropoutForward : public DropoutBase {
    DEF_OPR_IMPL(DropoutForward, DropoutBase, 1, 2);

public:
    //! infer oup and mask layouts from inp
    void deduce_layout(const TensorLayout& inp, TensorLayout& oup, TensorLayout& mask);
    //! apply dropout to inp; mask records which elements were kept so the
    //! backward pass can replay the same pattern
    virtual void exec(
            _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& inp, const TensorLayout& oup,
            const TensorLayout& mask) = 0;
    //! bytes needed for the mask tensor, given the input layout
    virtual size_t get_mask_size_in_bytes(const TensorLayout& inp) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& inp, const TensorLayout& oup, const TensorLayout& mask,
            size_t workspace_in_bytes);
};
using Dropout = DropoutForward;
class DropoutBackward : public DropoutBase {
    DEF_OPR_IMPL(DropoutBackward, DropoutBase, 2, 1);

public:
    //! infer dinp layout from doup and mask
    void deduce_layout(
            const TensorLayout& doup, const TensorLayout& mask, TensorLayout& dinp);
    //! propagate gradient doup through the mask saved by the forward pass
    virtual void exec(
            _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& doup, const TensorLayout& mask,
            const TensorLayout& dinp) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& doup, const TensorLayout& mask,
            const TensorLayout& dinp, size_t workspace_in_bytes);
};
/*!
 * \brief Base class for Softmax operators; holds the param and shared layout
 * helpers.
 */
class SoftmaxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
    DEF_OPR_PARAM(Softmax);

protected:
    //! output layout equals input layout
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    //! validate the forward layouts
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
};
class SoftmaxForward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);

public:
    /**
     * \param[in] input input tensor
     * \param[out] output output tensor
     */
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    //! infer output layout from input
    void deduce_layout(const TensorLayout& input, TensorLayout& output);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& output) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using Softmax = SoftmaxForward;
class SoftmaxBackward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);

public:
    //! compute grad_x (gradient wrt. the softmax input) from the forward
    //! input and diff (gradient wrt. the softmax output)
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
            _megdnn_workspace workspace) = 0;
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x, size_t workspace_in_bytes);
};
/*!
 * \brief Single-step RNN cell: dst is computed from input, previous hidden
 * state hx, and the input/hidden weight and bias pairs.
 */
class RNNCellForward : public OperatorBase {
    DEF_OPR_PARAM(RNNCell);
    DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! infer dst layout from the input layouts
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            TensorLayout& dst);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using RNNCell = RNNCellForward;
/*!
 * \brief Single-step LSTM cell: produces new hidden state h_new, new cell
 * state c_new, and the raw gate activations.
 */
class LSTMCellForward : public OperatorBase {
    // DEF_OPR_PARAM(LSTMCell);
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
            _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
    //! infer h_new/c_new/gates layouts from the input layouts
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
            TensorLayout& gates);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates,
            size_t workspace_in_bytes);
};
using LSTMCell = LSTMCellForward;
/*!
 * \brief Multi-step RNN over a full sequence, driven by flattened weights;
 * reserve_space carries intermediate state needed by RNNBackward.
 */
class RNNForward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
            _megdnn_workspace workspace) = 0;
    //! infer output/hy/reserve_space layouts from the input layouts
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
            TensorLayout& reserve_space);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
    //! bytes required for reserve_space, given the input layout
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space,
            size_t workspace_in_bytes);
};
using RNN = RNNForward;
/*!
 * \brief Backward pass of RNNForward: produces gradients wrt. input (dx),
 * initial hidden state (dhx) and flattened weights (dw), consuming the
 * reserve_space written by the forward pass.
 */
class RNNBackward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);

public:
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
            _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
            _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
            _megdnn_workspace workspace) = 0;
    //! infer dx/dhx/dw layouts from the input layouts
    void deduce_layout(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx,
            const TensorLayout& dw) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
            size_t workspace_in_bytes);
};
/*!
 * \brief Multi-step LSTM over a full sequence; like RNNForward but with an
 * additional cell state (cx in, cy out).
 */
class LSTMForward : public OperatorBase {
    DEF_OPR_PARAM(LSTM);
    DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out cy,
            _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
    //! infer output/hy/cy/reserve_space layouts from the input layouts
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
            TensorLayout& cy, TensorLayout& reserve_space);
    //! workspace bytes required by exec() for the given layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space) = 0;
    //! bytes required for reserve_space, given the input layout
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
    //! validate layouts and workspace size before dispatch
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space, size_t workspace_in_bytes);
};
using LSTM = LSTMForward;
  1979. class LSTMBackward : public OperatorBase {
  1980. DEF_OPR_PARAM(LSTM);
  1981. DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);
  1982. public:
  1983. virtual void exec(
  1984. _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
  1985. _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
  1986. _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
  1987. _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
  1988. _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
  1989. _megdnn_workspace workspace) = 0;
  1990. void deduce_layout(
  1991. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  1992. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  1993. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  1994. const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
  1995. TensorLayout& dcx, TensorLayout& dw);
  1996. virtual size_t get_workspace_in_bytes(
  1997. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  1998. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  1999. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  2000. const TensorLayout& reserve_space, const TensorLayout& dx,
  2001. const TensorLayout& dhx, const TensorLayout& dcx,
  2002. const TensorLayout& dw) = 0;
  2003. protected:
  2004. void check_exec(
  2005. const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
  2006. const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
  2007. const TensorLayout& dcy, const TensorLayout& flatten_weights,
  2008. const TensorLayout& reserve_space, const TensorLayout& dx,
  2009. const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
  2010. size_t workspace_in_bytes);
  2011. };
  2012. } // namespace megdnn
  2013. #include "megdnn/internal/opr_header_epilogue.h"
  2014. // vim: syntax=cpp.doxygen