
nn.h 80 kB

/**
* \file dnn/include/megdnn/oprs/nn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/internal/opr_header_prologue.h"
namespace megdnn {
class SeparableConvBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
DEF_OPR_PARAM(SeparableConv);
public:
using Mode = Param::Mode;
protected:
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst);
};
class SeparableConvForward : public SeparableConvBase {
DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using SeparableConv = SeparableConvForward;
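/*
* Usage sketch (not part of the original header): the forward operators declared in this
* file share one calling convention, shown here with SeparableConvForward. Handle
* creation, the alloc() helper and the tensor setup are illustrative assumptions only.
*
*   auto opr = handle->create_operator<SeparableConvForward>();
*   TensorLayout dst_layout;
*   opr->deduce_layout(src.layout, filter_x.layout, filter_y.layout, dst_layout);
*   size_t ws_size = opr->get_workspace_in_bytes(
*           src.layout, filter_x.layout, filter_y.layout, dst_layout);
*   Workspace workspace{static_cast<dt_byte*>(alloc(ws_size)), ws_size};
*   TensorND dst{alloc(dst_layout.span().dist_byte()), dst_layout};
*   opr->exec(src, filter_x, filter_y, dst, workspace);
*/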
namespace detail {
struct PreprocessedFilter {
//! user data; its lifetime should be bound to MegDNN Convolution
//! operator
void* algorithm_id;
TensorNDArray tensors;
};
} // namespace detail
/**
* \brief base class for convolution operation
*
* This operator is supposed to perform convolution on arbitrary input
* dimensions. The input/output format is N, C, dims..., and kernel format can
* take two forms:
* 1. OC, IC, dims..., for conventional dense convolution
* 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
*
* Currently, only 2D images are supported.
*/
template <typename Parameter>
class ConvolutionBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
using Param = Parameter;
public:
Param& param() { return m_param; }
const Param& param() const { return m_param; }
protected:
Param m_param;
public:
static constexpr size_t MAX_SPATIAL_DIM = 2;
using Mode = typename Param::Mode;
struct CanonizedFilterMeta {
DType dtype;
typename Param::Format format;
uint32_t
//! whether filter should be flipped (i.e. is CONVOLUTION)
should_flip,
group, //!< number of groups
icpg, //!< input channels per group
ocpg, //!< output channels per group
spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
//! spatial dim
spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
//! spatial dim with dilation applied
dilated_spatial[MAX_SPATIAL_DIM];
//! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
template <typename T>
void copy_from(const T& b) {
dtype = b.dtype;
format = b.format;
should_flip = b.should_flip;
group = b.group;
icpg = b.icpg;
ocpg = b.ocpg;
spatial_ndim = b.spatial_ndim;
memcpy(stride, b.stride, sizeof(stride));
memcpy(padding, b.padding, sizeof(padding));
memcpy(spatial, b.spatial, sizeof(spatial));
memcpy(dilation, b.dilation, sizeof(dilation));
memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
}
bool operator==(const CanonizedFilterMeta& b) const {
bool flag = true;
flag = flag && (format == b.format);
flag = flag && (dtype == b.dtype);
flag = flag && (should_flip == b.should_flip);
flag = flag && (group == b.group);
flag = flag && (icpg == b.icpg);
flag = flag && (ocpg == b.ocpg);
flag = flag && (spatial_ndim == b.spatial_ndim);
if (flag) {
for (uint32_t i = 0; i < spatial_ndim; ++i) {
flag = flag && (stride[i] == b.stride[i]);
flag = flag && (padding[i] == b.padding[i]);
flag = flag && (spatial[i] == b.spatial[i]);
flag = flag && (dilation[i] == b.dilation[i]);
flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
}
}
return flag;
}
};
using PreprocessedFilter = detail::PreprocessedFilter;
protected:
// Check or deduce output DType
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& filter,
TensorLayout& dst) const;
CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) const;
CanonizedFilterMeta make_canonized_filter_meta(
size_t src_ndim, const TensorLayout& filter) const;
};
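/*
* Filter-shape example for the two kernel formats described above (2D case): a dense
* convolution with ic = 16, oc = 32 and a 3x3 kernel uses an (OC, IC, FH, FW) filter of
* (32, 16, 3, 3), canonized to group = 1, ocpg = 32, icpg = 16; splitting the same
* convolution into 4 groups uses a (GROUP, OC_PER_GRP, IC_PER_GRP, FH, FW) filter of
* (4, 8, 4, 3, 3), canonized to group = 4, ocpg = 8, icpg = 4.
*/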
class MaskPropagate : public OperatorBase {
DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
DEF_OPR_PARAM(MaskPropagate);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
};
/**
* \brief ConvolutionForward Operator with 0/1 Mask matrix
*/
class MaskConvForward : public ConvolutionBase<param::Convolution> {
DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in mask, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& mask,
const TensorLayout& dst) = 0;
void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& mask, TensorLayout& dst);
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_in_bytes);
};
using MaskConvolution = MaskConvForward;
/**
* \brief ConvolutionForward operator.
*/
class ConvolutionForward : public ConvolutionBase<param::Convolution>,
public detail::MultiAlgoOpr<ConvolutionForward, 3> {
DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
public:
/**
* \param[in] src (n, ic, ih, iw)
* \param[in] filter (oc, ic, fh, fw)
* \param[in] preprocessed_filter nullptr if the weights have not been
* preprocessed; otherwise the preprocessed weights are stored in the
* tensors of preprocessed_filter.
* \param[in] workspace if the weights are not preprocessed
* (preprocessed_filter == nullptr), the workspace size is the one required
* for unpreprocessed weights; otherwise it is the one required for
* preprocessed weights
* \param[out] dst (n, oc, oh, ow)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
/**
* \brief execute weight preprocessing: read the weights from filter and
* write them to preprocessed_filter after preprocessing.
*
* \param[in] workspace the temporary workspace needed while exec_preprocess runs
*/
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
/**
* \brief query the workspace needed when executing the operator;
* preprocessed_filter is non-null if the weights are preprocessed and
* nullptr otherwise, and the workspace size may differ depending on
* whether the weights are preprocessed
*
* \return the size of workspace needed when executing
*/
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) = 0;
/**
* \brief deduce the preprocessed filter layouts from the src, filter and
* dst layouts; the result may contain multiple layouts when there is more
* than one weight tensor
*
* \return SmallVector<TensorLayout> the layouts of the preprocessed
* weights; empty if preprocessing is not needed.
*/
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
/**
* \brief query the workspace needed when preprocessing the weights; a
* _megdnn_workspace of the returned size will be created and passed to
* exec_preprocess
*
* \return the size of workspace needed when preprocessing
*/
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter);
};
using Convolution = ConvolutionForward;
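/*
* Weight pre-processing sketch for ConvolutionForward, based only on the methods
* declared above; handle creation, buffer allocation and make_workspace() are
* illustrative assumptions.
*
*   auto conv = handle->create_operator<ConvolutionForward>();
*   TensorLayout dst_layout;
*   conv->deduce_layout(src.layout, filter.layout, dst_layout);
*   // 1. query the pre-processed layouts; an empty result means no pre-processing
*   auto pf_layouts = conv->deduce_preprocessed_filter_layout(
*           src.layout, filter.layout, dst_layout);
*   PreprocessedFilter pf;  // pf.tensors allocated according to pf_layouts (omitted)
*   // 2. run the pre-processing pass once, with its own workspace
*   size_t pre_ws = conv->get_preprocess_workspace_in_bytes(
*           src.layout, filter.layout, dst_layout);
*   conv->exec_preprocess(src.layout, filter, dst_layout, &pf, make_workspace(pre_ws));
*   // 3. execute with the pre-processed weights; pass nullptr instead of &pf when the
*   //    weights are not pre-processed
*   size_t ws = conv->get_workspace_in_bytes(src.layout, filter.layout, dst_layout, &pf);
*   conv->exec(src, filter, dst, &pf, make_workspace(ws));
*/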
/**
* \brief ConvolutionBackwardData operator.
*
* Calculating the gradient wrt. convolution input data.
*/
class ConvolutionBackwardData
: public ConvolutionBase<param::Convolution>,
public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
public:
/**
* \param[in] filter (oc, ic, fh, fw)
* \param[in] diff (n, oc, oh, ow)
* \param[out] grad (n, ic, ih, iw)
*/
virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
void deduce_dtype(DType filter, DType diff, DType& grad);
void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
TensorLayout& grad);
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes);
};
/**
* \brief ConvolutionBackwardFilter operator.
*
* Calculating the gradient wrt. convolution filter.
*/
class ConvolutionBackwardFilter
: public ConvolutionBase<param::Convolution>,
public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
public:
/**
* \param[in] src (n, ic, ih, iw)
* \param[in] diff (n, oc, oh, ow)
* \param[out] grad (oc, ic, fh, fw)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes);
};
/**
* \brief ConvolutionBias operator
*/
class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
public detail::MultiAlgoOpr<ConvBiasForward, 5> {
DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
public:
/**
* \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
* \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
* 4 * ic)
* \param[in] bias (1, oc, 1, 1)
* \param[in] z same as dst
* \param[in] preprocessed_filter nullptr if the weights have not been
* preprocessed; otherwise the preprocessed weights are stored in the
* tensors of preprocessed_filter.
* \param[in] workspace if the weights are not preprocessed
* (preprocessed_filter == nullptr), the workspace size is the one required
* for unpreprocessed weights; otherwise it is the one required for
* preprocessed weights
* \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
*
* \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
* alphaw, oc, ic)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
/**
* \brief execute weight preprocessing: read the weights from filter and
* bias, and write them to preprocessed_filter after preprocessing.
*
* \param[in] workspace the temporary workspace needed while exec_preprocess
* runs; its size is obtained from get_preprocess_workspace_in_bytes
*/
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
_megdnn_tensor_in bias,
const TensorLayout& z_layout,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
TensorLayout& dst);
/**
* \brief query the workspace needed when executing the operator;
* preprocessed_filter is non-null if the weights are preprocessed and
* nullptr otherwise, and the workspace size may differ depending on
* whether the weights are preprocessed
*
* \return the size of workspace needed when executing
*/
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) = 0;
/**
* \brief query the workspace needed when pre-processing the weights; a
* _megdnn_workspace of the returned size will be created and passed to
* exec_preprocess
*
* \return the size of workspace needed when pre-processing
*/
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
/**
* \brief deduce the pre-processed filter layouts from the src, filter and
* dst layouts; the result may contain multiple layouts when there is more
* than one weight tensor
*
* \return SmallVector<TensorLayout> the layouts of the pre-processed
* weights; empty if preprocessing is not needed.
*/
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
enum class BiasMode : uint32_t {
NO_BIAS = 0, //!< no bias
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
BIAS //!< [N, C, H, W]
};
//! param for winograd algos.
struct WinogradParam {
uint32_t channel_block_size;
uint32_t output_block_size;
uint32_t tile_size;
bool operator==(const WinogradParam& rhs) const {
return channel_block_size == rhs.channel_block_size &&
output_block_size == rhs.output_block_size &&
tile_size == rhs.tile_size;
}
std::string to_string() const;
};
static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
struct DirectParam {
std::string to_string() const { return ""; }
};
struct MatmulParam {
std::string to_string() const { return ""; }
};
struct DefaultParam {
std::string to_string() const { return ""; }
};
//! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
//! \warning base must not contain ':'.
template <typename T>
static std::string algo_name(
const std::string& base, const T& p,
param::ConvBias::Format format = param::ConvBias::Format::NCHW);
/*!
* \brief parse algo_name and get WinogradParam from algo name.
*
* \param algo_name the algorithm name string
* \return WinogradParam parsed from algo name, use pattern
* winograd:base:m:tile_size.
*
* \warning INVALID_WINOGRAD_PARAM is returned if the algo_name does not match.
*/
static WinogradParam parse_winograd_name(const std::string& algo_name);
/**
* @brief find whether there is an nchw_nchwxx conv kernel optimized for the
* given arguments; nchw44 is used for arm, nchw88 for x86
*
* @param src_dtype conv feature map data type
* @param filter_dtype conv filter or weight data type
* @param dst_dtype output data type
* @param fm filter meta param
* @param bias_mode bias mode, no_bias or broadcast or bias
* @param nonline_mode identity or relu or h_swish or sigmoid
* @return true, found a kernel
* @return false, can't find any kernel
*/
static bool is_nchw_nchwxx_optimized(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode);
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVBIAS_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter);
CanonizedFilterMeta check_exec_allow_noncontiguous(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes,
const PreprocessedFilter* preprocessed_filter);
};
using ConvBias = ConvBiasForward;
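/*
* Algorithm-name sketch based on algo_name() and parse_winograd_name() above: names
* follow "ParamTrait<T>::category:base:p.to_string()", and Winograd names follow the
* "winograd:base:m:tile_size" pattern. The concrete name below is hypothetical.
*
*   auto p = ConvBias::parse_winograd_name("WINOGRAD:F32_4x4:2:8");
*   bool is_winograd = !(p == ConvBias::INVALID_WINOGRAD_PARAM);
*/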
/**
* \brief base class for Conv - Nonline - Pooling
*/
class ConvPoolingBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
/**
* Param::Method: two methods to fetch the input data.
* The default method is WITH_TEXTURE_OBJ.
* If you want to use WITH_SHARED_MEM mode,
* please make sure that the total size of
* [all of the filter kernels + a channel
* of input data + a channel of output data]
* is no larger than 38KB,
* and the pooling mode is not "MAX".
*/
DEF_OPR_PARAM(ConvPooling);
protected:
virtual void deduce_layout(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias, TensorLayout& dst) = 0;
virtual void check_layout(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias, TensorLayout& dst,
size_t workspace_limit_in_bytes) = 0;
};
class ConvPoolingForward : public ConvPoolingBase {
DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);
public:
/**
* \param[in] src input tensor
* \param[out] dst output tensor
*/
virtual void exec(const _megdnn_in TensorND src,
const _megdnn_in TensorND filter,
const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
_megdnn_out Workspace workspace) = 0;
virtual void deduce_layout(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias, TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& dst) = 0;
protected:
virtual void check_layout(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias, TensorLayout& dst,
size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;
class GroupLocalBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
DEF_OPR_PARAM(Convolution);
public:
using Mode = Param::Mode;
protected:
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst);
};
class GroupLocalForward : public GroupLocalBase {
DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);
public:
/**
* \param[in] src (N, IC, IH, IW)
* \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
* \param[out] dst (N, OC, OH, OW)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst) {
deduce_layout_fwd(src, filter, dst);
}
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;
class GroupLocalBackwardData : public GroupLocalBase {
DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);
public:
virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};
class GroupLocalBackwardFilter : public GroupLocalBase {
DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};
class Images2NeibsBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
DEF_OPR_PARAM(Images2Neibs);
protected:
void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
};
class Images2NeibsForward : public Images2NeibsBase {
DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);
public:
/**
* \param[in] src (N, C, IH, IW)
* \param[out] dst (N, C, OH, OW, window_h, window_w)
*
* \see
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
*
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
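/*
* Shape example for Images2NeibsForward, assuming src (N, C, IH, IW) = (1, 1, 4, 4), a
* 2x2 window, stride 2, no padding and no dilation (numbers chosen for illustration):
* dst is (N, C, OH, OW, window_h, window_w) = (1, 1, 2, 2, 2, 2), and each (oh, ow)
* slice holds the 2x2 patch of src whose top-left corner is (oh * 2, ow * 2).
*/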
class Images2NeibsBackward : public Images2NeibsBase {
DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);
public:
/**
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*/
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};
class SlidingWindowTransposeBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(SlidingWindowTransposeBase, OperatorBase);
DEF_OPR_PARAM(SlidingWindowTranspose);
protected:
void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
};
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1,
1);
public:
/**
* \param[in] src (N, C, IH, IW, window_h, window_w)
* \param[out] dst (N, C, OH, OW)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using SlidingWindowTranspose = SlidingWindowTransposeForward;
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase {
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1,
1);
public:
/**
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*/
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};
/**
* \brief base class for Pooling
*/
class PoolingBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
DEF_OPR_PARAM(Pooling);
public:
using Mode = Param::Mode;
protected:
void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
class PoolingForward : public PoolingBase,
public detail::MultiAlgoOpr<PoolingForward, 2> {
DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);
public:
/**
* \param[in] src input tensor
* \param[out] dst output tensor
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::POOLING_FORWARD;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
class PoolingBackward : public PoolingBase,
public detail::MultiAlgoOpr<PoolingBackward, 4> {
DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);
public:
/**
* \param[in] src the `src' parameter in PoolingForward::exec
* \param[in] dst the `dst' parameter in PoolingForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::POOLING_BACKWARD;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};
/**
* \brief base class for AdaptivePooling
*/
class AdaptivePoolingBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
DEF_OPR_PARAM(AdaptivePooling);
protected:
param::Pooling deduce_pooling_param(const TensorLayout& src,
const TensorLayout& dst);
};
class AdaptivePoolingForward : public AdaptivePoolingBase {
DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);
public:
/**
* \param[in] src input tensor
* \param[out] dst output tensor
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
};
using AdaptivePooling = AdaptivePoolingForward;
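/*
* Illustration of deduce_pooling_param() above: adaptive pooling fixes the output
* spatial size and derives the window and stride from the src and dst layouts. Assuming
* the conventional mapping stride = IH / OH and window = IH - (OH - 1) * stride (an
* assumption, not stated in this header), src (N, C, 32, 32) with dst (N, C, 8, 8)
* yields stride 4 and window 4 in the deduced param::Pooling.
*/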
class AdaptivePoolingBackward : public AdaptivePoolingBase {
DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);
public:
/**
* \param[in] src the `src' parameter in AdaptivePoolingForward::exec
* \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
};
/**
* \brief base class for Local
*/
class LocalBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
DEF_OPR_PARAM(Convolution);
public:
using Mode = Param::Mode;
protected:
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst);
};
class LocalForward : public LocalBase {
DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);
public:
/**
* \param[in] src (n, ic, ih, iw)
* \param[in] filter (oh, ow, ic, fh, fw, oc)
* \param[out] dst (n, oc, oh, ow)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
/**
* \brief Deducing output tensor layouts from input tensor layouts.
*
* Be aware that the first and second dimension of `filter' are ignored
* when deducing `dst' layout.
*/
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
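/*
* Shape example for LocalForward: unlike convolution, every output position owns its own
* weights. Assuming src (n, ic, ih, iw) = (1, 8, 32, 32), a 3x3 kernel with stride 1 and
* padding 1, and oc = 16 (illustrative numbers), the filter layout
* (oh, ow, ic, fh, fw, oc) is (32, 32, 8, 3, 3, 16) and dst is (1, 16, 32, 32).
*/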
class LocalBackwardData : public LocalBase {
DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);
public:
/**
* \param[in] filter (oh, ow, ic, fh, fw, oc)
* \param[in] diff (n, oc, oh, ow)
* \param[out] grad (n, ic, ih, iw)
*/
virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalBackwardFilter : public LocalBase {
DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);
public:
/**
* \param[in] src (n, ic, ih, iw)
* \param[in] diff (n, oc, oh, ow)
* \param[out] grad (oh, ow, ic, fh, fw, oc)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};
class BNBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
DEF_OPR_PARAM(BN);
protected:
void check_param();
};
class BNForward : public BNBase {
DEF_OPR_IMPL(BNForward, BNBase, 6, 6);
public:
/**
* dst[i] = gamma * (x[i] - estimatedMean[k]) / sqrt(epsilon + estimatedVariance[k]) + beta,
* where epsilon is a very small value to avoid a "divide by zero" error.
* \param[in] src (n, c, h, w)
* \param[out] dst (n, c, h, w)
* \param[out] mean (see m_param.ParamDim) Global mean.
* \param[out] variance (see m_param.ParamDim) Global variance.
* \param[out] batch_mean (see m_param.ParamDim)
* Optionally cached intermediate mean from forward pass
* \param[out] batch_inv_variance (see m_param.ParamDim)
* Optionally cached intermediate variance from forward pass
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
* src and dst must have the same shape.
* src and dst must be contiguous.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
_megdnn_tensor_inout variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out reserve, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, TensorLayout& mean,
TensorLayout& variance, TensorLayout& batch_mean,
TensorLayout& batch_inv_variance, TensorLayout& reserve,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance, const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance, const TensorLayout& reserve,
const TensorLayout& dst) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
const TensorLayout& bn_bias, const TensorLayout& mean,
const TensorLayout& variance,
const TensorLayout& batch_mean,
const TensorLayout& batch_inv_variance,
const TensorLayout& dst, size_t workspace_in_bytes,
size_t reserve_in_bytes = 0);
};
using BN = BNForward;
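/*
* Worked example of the BNForward formula above: with x[i] = 3, estimatedMean[k] = 1,
* estimatedVariance[k] = 3.99, epsilon = 0.01, gamma = 2 and beta = 0.5,
* dst[i] = 2 * (3 - 1) / sqrt(0.01 + 3.99) + 0.5 = 2 * 2 / 2 + 0.5 = 2.5.
* The numbers are illustrative only.
*/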
class BNBackward : public BNBase {
DEF_OPR_IMPL(BNBackward, BNBase, 6, 3);
public:
/**
* \param[in] x input data of the forward propagation.
* \param[in] dy the backpropagated gradient of y.
* \param[out] dx the backpropagated gradient of x.
* \param[out] d_bn_scale the backpropagated gradient of bn_scale.
* \param[out] d_bn_bias the backpropagated gradient of bn_bias.
* Optionally cached intermediate results from forward pass
* \param[in] saved_batch_mean mean of the input batch.
* Calculated in the forward propagation.
* \param[in] saved_batch_variance variance of the input batch.
* Calculated in the forward propagation.
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
*/
virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& saved_batch_mean,
const TensorLayout& saved_batch_variance,
const TensorLayout& bn_scale, const TensorLayout& reserve,
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias,
const TensorLayout& dx) = 0;
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0;
protected:
void check_exec(const TensorLayout& x, const TensorLayout& dy,
const TensorLayout& saved_batch_mean,
const TensorLayout& saved_batch_variance,
const TensorLayout& bn_scale,
const TensorLayout& d_bn_scale,
const TensorLayout& d_bn_bias, const TensorLayout& dx,
size_t workspace_in_bytes, size_t reserve_in_bytes = 0);
};
class LRNBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
DEF_OPR_PARAM(LRN);
protected:
void check_param();
};
class LRNForward : public LRNBase {
DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);
public:
/**
* \see ImageNet Classification with Deep Convolutional Neural Networks
* \param[in] src (n, c, h, w)
* \param[out] dst (n, c, h, w)
*
* src and dst must have the same shape.
* src and dst must be contiguous.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using LRN = LRNForward;
class LRNBackward : public LRNBase {
DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);
public:
/**
* \param[in] src the `src' parameter in LRNForward::exec
* \param[in] dst the `dst' parameter in LRNForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
*
* All tensors should be contiguous and of the same shape.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};
class ROIPoolingBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
DEF_OPR_PARAM(ROIPooling);
protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
const TensorLayout& dst, const TensorLayout& index);
};
class ROIPoolingForward : public ROIPoolingBase {
DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);
public:
/**
* \param[in] src (n, c, ih, iw)
* \param[in] rois (m, 5)
* \param[out] dst (m, c, oh, ow)
* \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
*
* The internal implementation is akin to
* https://github.com/rbgirshick/caffe-fast-rcnn .
* Note that rois(:, 0) denotes the input image index. We store it as
* a float, but it should be an integer instead.
*
* index is a temporary tensor to facilitate its backward operator.
* It is used to store argmax indices in MAX mode, and it is not used
* in AVERAGE mode.
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
_megdnn_tensor_out dst, _megdnn_tensor_out index,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& rois,
const TensorLayout& dst,
const TensorLayout& index) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& rois,
const TensorLayout& dst, const TensorLayout& index,
size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
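/*
* Layout note for ROIPoolingForward: rois is (m, 5); as stated above, column 0 holds the
* index of the source image within the batch (stored as float). Treating the remaining
* four columns as the box coordinates (x0, y0, x1, y1) follows the Fast R-CNN reference
* linked above and is an assumption, not something this header specifies. dst and index
* are (m, c, oh, ow), i.e. one pooled map per ROI.
*/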
  1026. class ROIPoolingBackward : public ROIPoolingBase {
  1027. DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);
  1028. public:
  1029. /**
  1030. * \param[in] diff the backpropagated gradient wrt. dst
  1031. * \param[in] src the `src' parameter in ROIPoolingForward::exec
  1032. * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
  1033. * \param[in] index the `index' parameter in ROIPoolingForward::exec
  1034. * \param[out] grad the backpropagated gradient wrt. src
  1035. */
  1036. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
  1037. _megdnn_tensor_in rois, _megdnn_tensor_in index,
  1038. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  1039. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  1040. const TensorLayout& src,
  1041. const TensorLayout& rois,
  1042. const TensorLayout& index,
  1043. const TensorLayout& grad) = 0;
  1044. protected:
  1045. void check_exec(const TensorLayout& diff, const TensorLayout& src,
  1046. const TensorLayout& rois, const TensorLayout& index,
  1047. const TensorLayout& grad, size_t workspace_in_bytes);
  1048. };
  1049. class Convolution3DBase : public OperatorBase {
  1050. DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
  1051. DEF_OPR_PARAM(Convolution3D);
  1052. public:
  1053. static constexpr size_t MAX_SPATIAL_DIM = 3;
  1054. using Mode = Param::Mode;
  1055. struct CanonizedFilterMeta {
  1056. DTypeEnum dtype_enum;
  1057. Param::Format format;
  1058. uint32_t
  1059. //! whether filter should be flipped (i.e. is CONVOLUTION)
  1060. should_flip,
  1061. group, //!< number of groups
  1062. icpg, //!< input channels per group
  1063. ocpg, //!< output channels per group
  1064. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  1065. //! spatial dim
  1066. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  1067. //! spatial dim with dilation applied
  1068. dilated_spatial[MAX_SPATIAL_DIM];
  1069. } MEGDNN_PACKED;
  1070. protected:
  1071. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  1072. const TensorLayout& filter,
  1073. TensorLayout& dst) const;
  1074. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  1075. const TensorLayout& filter,
  1076. const TensorLayout& dst) const;
  1077. CanonizedFilterMeta make_canonized_filter_meta(
  1078. size_t src_ndim, const TensorLayout& filter) const;
  1079. };
  1080. class Convolution3DForward
  1081. : public Convolution3DBase,
  1082. public detail::MultiAlgoOpr<Convolution3DForward, 3> {
  1083. DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);
  1084. public:
  1085. /**
  1086. * \param[in] src (n, ic, id, ih, iw)
  1087. * \param[in] filter (oc, ic, fd, fh, fw)
  1088. * \param[out] dst (n, oc, od, oh, ow)
  1089. */
  1090. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  1091. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  1092. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  1093. TensorLayout& dst);
  1094. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  1095. const TensorLayout& filter,
  1096. const TensorLayout& dst) = 0;
  1097. static Algorithm::OprType get_opr_type() {
  1098. return Algorithm::OprType::CONVOLUTION3D_FORWARD;
  1099. }
  1100. protected:
  1101. CanonizedFilterMeta check_exec(const TensorLayout& src,
  1102. const TensorLayout& filter,
  1103. const TensorLayout& dst,
  1104. size_t workspace_in_bytes);
  1105. };
  1106. using Convolution3D = Convolution3DForward;
  1107. class Convolution3DBackwardData
  1108. : public Convolution3DBase,
  1109. public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
  1110. DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);
  1111. public:
  1112. /**
  1113. * \param[in] filter (oc, ic, fd, fh, fw)
  1114. * \param[in] diff (n, oc, od, oh, ow)
  1115. * \param[out] grad (n, ic, id, ih, iw)
  1116. */
  1117. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  1118. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  1119. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  1120. const TensorLayout& diff,
  1121. const TensorLayout& grad) = 0;
  1122. void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
  1123. TensorLayout& grad);
  1124. static Algorithm::OprType get_opr_type() {
  1125. return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
  1126. }
  1127. protected:
  1128. CanonizedFilterMeta check_exec(const TensorLayout& filter,
  1129. const TensorLayout& diff,
  1130. const TensorLayout& grad,
  1131. size_t workspace_in_bytes);
  1132. };
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
    }

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *                    FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_FORWARD;
    }

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
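
/*
 * Usage sketch (illustration only; not part of the original interface).
 * Unlike an ordinary convolution, the local-share filter is 7-dimensional:
 * every spatial group (spatial_groups_h x spatial_groups_w) owns its own
 * (IC / G, FH, FW, OC / G) weight block. The concrete sizes below are made-up
 * example values, and the operator's LocalShare param (stride, padding,
 * spatial group counts) is assumed to have been set consistently elsewhere.
 *
 * \code
 *     void prepare_local_share(megdnn::LocalShareForward* opr) {
 *         using namespace megdnn;
 *         const size_t G = 1, sgh = 2, sgw = 2, IC = 16, OC = 32, FH = 3, FW = 3;
 *         TensorLayout src({8, IC, 24, 24}, dtype::Float32());
 *         TensorLayout filter({G, sgh, sgw, IC / G, FH, FW, OC / G},
 *                             dtype::Float32());
 *         TensorLayout dst;
 *         opr->deduce_layout(src, filter, dst);  // dst: (N, OC, OH, OW)
 *         size_t workspace = opr->get_workspace_in_bytes(src, filter, dst);
 *         (void)workspace;  // allocate at least this many bytes before exec()
 *     }
 * \endcode
 */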
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *                    FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
    }

protected:
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     *                   FH, FW, OC / G)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
    }

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
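
/*
 * Usage sketch (illustration only; not part of the original interface).
 * Note that LocalShareBackwardFilter has no deduce_layout(): the filter
 * gradient simply reuses the layout of the forward filter. The tensors and
 * the operator are assumed to come from the corresponding forward pass.
 *
 * \code
 *     void local_share_bwd_filter(megdnn::LocalShareBackwardFilter* opr,
 *                                 const megdnn::TensorND& src,   // (N, IC, IH, IW)
 *                                 const megdnn::TensorND& diff,  // (N, OC, OH, OW)
 *                                 const megdnn::TensorND& grad,  // same layout as forward filter
 *                                 const megdnn::Workspace& workspace) {
 *         // precondition: workspace.size >= opr->get_workspace_in_bytes(
 *         //         src.layout, diff.layout, grad.layout)
 *         opr->exec(src, diff, grad, workspace);
 *     }
 * \endcode
 */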
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(:, 0) denotes the input image index. It is stored as
     * a float, but it is conceptually an integer.
     *
     * index is a temporary tensor to facilitate the backward operator.
     * It stores the argmax indices in MAX mode and is unused in AVERAGE
     * mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
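
/*
 * Usage sketch (illustration only; not part of the original interface).
 * deduce_layout() yields both the pooled output and the auxiliary `index'
 * tensor; the latter only carries data in MAX mode and must be kept around
 * for ROIAlignBackward. The concrete sizes below are made-up example values.
 *
 * \code
 *     void prepare_roi_align(megdnn::ROIAlignForward* opr, size_t num_rois) {
 *         using namespace megdnn;
 *         TensorLayout src({2, 256, 50, 50}, dtype::Float32());
 *         TensorLayout rois({num_rois, 5}, dtype::Float32());  // rois(:, 0) = image index
 *         TensorLayout dst, index;
 *         opr->deduce_layout(src, rois, dst, index);
 *         size_t ws = opr->get_workspace_in_bytes(src, rois, dst, index);
 *         (void)ws;  // allocate at least this many bytes before exec()
 *     }
 * \endcode
 */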
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
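
/*
 * Usage sketch (illustration only; not part of the original interface).
 * The backward pass consumes the same `rois' and `index' tensors that the
 * forward pass used/produced, and writes the gradient wrt. src.
 *
 * \code
 *     void roi_align_bwd(megdnn::ROIAlignBackward* opr,
 *                        const megdnn::TensorND& diff,   // grad wrt. dst
 *                        const megdnn::TensorND& rois,   // as passed to the forward pass
 *                        const megdnn::TensorND& index,  // as produced by the forward pass (MAX mode)
 *                        const megdnn::TensorND& grad,   // grad wrt. src
 *                        const megdnn::Workspace& workspace) {
 *         opr->exec(diff, rois, index, grad, workspace);
 *     }
 * \endcode
 */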
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };

protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
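
/*
 * Usage sketch (illustration only; not part of the original interface).
 * The offset and mask tensors are extra inputs: `dg' is the number of
 * deformable groups, and each filter tap gets a 2-D offset and a scalar mask
 * at every output position, following the shapes documented above. The
 * layouts are assumed to be built by the caller.
 *
 * \code
 *     size_t prepare_deformable_conv(megdnn::DeformableConvForward* opr,
 *                                    const megdnn::TensorLayout& im,
 *                                    const megdnn::TensorLayout& filter,
 *                                    const megdnn::TensorLayout& offset,
 *                                    const megdnn::TensorLayout& mask,
 *                                    megdnn::TensorLayout& dst) {
 *         opr->deduce_layout(im, filter, offset, mask, dst);  // dst: (n, oc, oh, ow)
 *         return opr->get_workspace_in_bytes(im, filter, offset, mask, dst);
 *     }
 * \endcode
 */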
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Computes the gradient wrt. the convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
    }

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Computes the gradient wrt. the convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
    }

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
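
/*
 * Usage sketch (illustration only; not part of the original interface).
 * The data-backward operator produces three gradients at once (wrt. the
 * input image, the offsets and the mask); deduce_layout() fills all three
 * output layouts from the forward-side layouts.
 *
 * \code
 *     size_t prepare_deformable_conv_bwd_data(
 *             megdnn::DeformableConvBackwardData* opr,
 *             const megdnn::TensorLayout& im, const megdnn::TensorLayout& filter,
 *             const megdnn::TensorLayout& offset, const megdnn::TensorLayout& mask,
 *             const megdnn::TensorLayout& out_grad) {
 *         megdnn::TensorLayout im_grad, offset_grad, mask_grad;
 *         opr->deduce_layout(im, filter, offset, mask, out_grad, im_grad,
 *                            offset_grad, mask_grad);
 *         return opr->get_workspace_in_bytes(im, filter, offset, mask, out_grad,
 *                                            im_grad, offset_grad, mask_grad);
 *     }
 * \endcode
 */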
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data (n, ic, ih, iw)
     * \param[out] out_count (n, ic, ih, iw)
     */
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
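
/*
 * Usage sketch (illustration only; not part of the original interface).
 * As with the other operators in this header, the out_data/out_count layouts
 * can be deduced first and the workspace sized before exec(); out_count is an
 * auxiliary output that the backward operator consumes.
 *
 * \code
 *     size_t prepare_deformable_psroi_pooling(
 *             megdnn::DeformablePSROIPoolingForward* opr,
 *             const megdnn::TensorLayout& data, const megdnn::TensorLayout& rois,
 *             const megdnn::TensorLayout& trans) {
 *         megdnn::TensorLayout out_data, out_count;
 *         opr->deduce_layout(data, rois, trans, out_data, out_count);
 *         return opr->get_workspace_in_bytes(data, rois, trans, out_data,
 *                                            out_count);
 *     }
 * \endcode
 */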
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx)
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff (n, ic, ih, iw)
     * \param[out] trans_diff (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;
    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCH_CONV_FORWARD;
    }

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
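
/*
 * Usage sketch (illustration only; not part of the original interface).
 * BatchConvBiasForward additionally exposes deduce_dtype(), which is useful
 * when the output dtype depends on the input dtypes (e.g. quantized setups);
 * the deduced dtype can then be used when allocating dst.
 *
 * \code
 *     size_t prepare_batch_conv_bias(megdnn::BatchConvBiasForward* opr,
 *                                    const megdnn::TensorLayout& src,
 *                                    const megdnn::TensorLayout& filter,
 *                                    const megdnn::TensorLayout& bias,
 *                                    const megdnn::TensorLayout& z,
 *                                    megdnn::TensorLayout& dst) {
 *         megdnn::DType dst_dtype;
 *         opr->deduce_dtype(src.dtype, filter.dtype, bias.dtype, z.dtype,
 *                           dst_dtype);
 *         (void)dst_dtype;  // e.g. for allocating dst with the right dtype
 *         opr->deduce_layout(src, filter, bias, z, dst);
 *         return opr->get_workspace_in_bytes(src, filter, bias, z, dst);
 *     }
 * \endcode
 */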
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);

protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
                          const TensorLayout& zero_point,
                          const TensorLayout& output);
};
class FakeQuantForward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);

public:
    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                      _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
                       const TensorLayout& zero_point, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& output) = 0;

protected:
    void check_exec(const TensorLayout& input, const TensorLayout& scale,
                    const TensorLayout& zero_point, const TensorLayout& output,
                    size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
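
/*
 * Usage sketch (illustration only; not part of the original interface).
 * FakeQuantForward quantizes and immediately dequantizes `input' using the
 * `scale' and `zero_point' tensors; the output layout is deduced from the
 * input (see deduce_layout()). The call pattern mirrors the other forward
 * operators in this header.
 *
 * \code
 *     void run_fake_quant(megdnn::FakeQuantForward* opr,
 *                         const megdnn::TensorND& input,
 *                         const megdnn::TensorND& scale,
 *                         const megdnn::TensorND& zero_point,
 *                         const megdnn::TensorND& output,
 *                         const megdnn::Workspace& workspace) {
 *         // precondition: output.layout was deduced via opr->deduce_layout()
 *         // and workspace.size >= opr->get_workspace_in_bytes(...)
 *         opr->exec(input, scale, zero_point, output, workspace);
 *     }
 * \endcode
 */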
class FakeQuantBackward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                      _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& input,
                    const TensorLayout& scale, const TensorLayout& zero_point,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
class TQTBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase);
    DEF_OPR_PARAM(TQT);

protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
                          const TensorLayout& output);
};

class TQTForward : public TQTBase {
    DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1);

public:
    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                      _megdnn_tensor_out output,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
                       TensorLayout& output);
    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& output) = 0;

protected:
    void check_exec(const TensorLayout& input, const TensorLayout& scale,
                    const TensorLayout& output, size_t workspace_in_bytes);
};
using TQT = TQTForward;
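
/*
 * Usage sketch (illustration only; not part of the original interface).
 * TQT (trained quantization thresholds) takes a learnable `scale' tensor in
 * addition to the input; its backward variant below therefore returns two
 * gradients, grad_x and grad_s. The forward call pattern is:
 *
 * \code
 *     size_t prepare_tqt(megdnn::TQTForward* opr,
 *                        const megdnn::TensorLayout& input,
 *                        const megdnn::TensorLayout& scale,
 *                        megdnn::TensorLayout& output) {
 *         opr->deduce_layout(input, scale, output);
 *         return opr->get_workspace_in_bytes(input, scale, output);
 *     }
 * \endcode
 */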
class TQTBackward : public TQTBase {
    DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                      _megdnn_tensor_in scale, _megdnn_tensor_out grad_x,
                      _megdnn_tensor_out grad_s,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& grad_x,
                                          const TensorLayout& grad_s) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& input,
                    const TensorLayout& scale, const TensorLayout& grad_x,
                    const TensorLayout& grad_s, size_t workspace_in_bytes);
};
class LSQBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
    DEF_OPR_PARAM(LSQ);

protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
                          const TensorLayout& zero_point,
                          const TensorLayout& grad_scale,
                          const TensorLayout& output);
};

class LSQForward : public LSQBase {
    DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                      _megdnn_tensor_in zero_point,
                      _megdnn_tensor_in grad_scale, _megdnn_tensor_out output,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
                       const TensorLayout& zero_point,
                       const TensorLayout& grad_scale, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& grad_scale,
                                          const TensorLayout& output) = 0;

protected:
    void check_exec(const TensorLayout& input, const TensorLayout& scale,
                    const TensorLayout& zero_point,
                    const TensorLayout& grad_scale, const TensorLayout& output,
                    size_t workspace_in_bytes);
};
using LSQ = LSQForward;
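
/*
 * Usage sketch (illustration only; not part of the original interface).
 * LSQ (learned step size quantization) takes four inputs: the data, the
 * learnable step `scale', a `zero_point' and a `grad_scale' factor. The
 * forward call pattern matches the other quantization-aware operators here:
 *
 * \code
 *     size_t prepare_lsq(megdnn::LSQForward* opr,
 *                        const megdnn::TensorLayout& input,
 *                        const megdnn::TensorLayout& scale,
 *                        const megdnn::TensorLayout& zero_point,
 *                        const megdnn::TensorLayout& grad_scale,
 *                        megdnn::TensorLayout& output) {
 *         opr->deduce_layout(input, scale, zero_point, grad_scale, output);
 *         return opr->get_workspace_in_bytes(input, scale, zero_point,
 *                                            grad_scale, output);
 *     }
 * \endcode
 */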
class LSQBackward : public LSQBase {
    DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                      _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
                      _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
                      _megdnn_tensor_out grad_s,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& grad_scale,
                                          const TensorLayout& grad_x,
                                          const TensorLayout& grad_s) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& input,
                    const TensorLayout& scale, const TensorLayout& zero_point,
                    const TensorLayout& grad_scale, const TensorLayout& grad_x,
                    const TensorLayout& grad_s, size_t workspace_in_bytes);
};
}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen

The MegEngine installation package already bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.