You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 61 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486
  1. /**
  2. * \file dnn/include/megdnn/oprs/nn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. class SeparableConvBase : public OperatorBase {
  15. DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
  16. DEF_OPR_PARAM(SeparableConv);
  17. public:
  18. using Mode = Param::Mode;
  19. protected:
  20. void deduce_layout_fwd(const TensorLayout& src,
  21. const TensorLayout& filter_x,
  22. const TensorLayout& filter_y, TensorLayout& dst);
  23. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
  24. const TensorLayout& filter_y,
  25. const TensorLayout& dst);
  26. };
  27. class SeparableConvForward : public SeparableConvBase {
  28. DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
  29. public:
  30. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  31. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  32. _megdnn_workspace workspace) = 0;
  33. void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
  34. const TensorLayout& filter_y, TensorLayout& dst);
  35. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  36. const TensorLayout& filter_x,
  37. const TensorLayout& filter_y,
  38. const TensorLayout& dst) = 0;
  39. protected:
  40. void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
  41. const TensorLayout& filter_y, const TensorLayout& dst,
  42. size_t workspace_in_bytes);
  43. };
  44. using SeparableConv = SeparableConvForward;
  45. /**
  46. * \brief base class for convolution operation
  47. *
  48. * This operator is supposed to perform convolution on arbitrary input
  49. * dimensions. The input/output format is N, C, dims..., and kernel format can
  50. * take two forms:
  51. * 1. OC, IC, dims..., for conventional dense convolution
  52. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  53. *
  54. * Currently, only 2D images are supported.
  55. */
  56. template <typename Parameter>
  57. class ConvolutionBase : public OperatorBase {
  58. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  59. using Param = Parameter;
  60. public:
  61. Param& param() { return m_param; }
  62. const Param& param() const { return m_param; }
  63. protected:
  64. Param m_param;
  65. public:
  66. static constexpr size_t MAX_SPATIAL_DIM = 2;
  67. using Mode = typename Param::Mode;
  68. struct CanonizedFilterMeta {
  69. DType dtype;
  70. typename Param::Format format;
  71. uint32_t
  72. //! whether filter should be flipped (i.e. is CONVOLUTION)
  73. should_flip,
  74. group, //!< number of groups
  75. icpg, //!< input channels per group
  76. ocpg, //!< output channels per group
  77. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  78. //! spatial dim
  79. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  80. //! spatial dim with dilation applied
  81. dilated_spatial[MAX_SPATIAL_DIM];
  82. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  83. template <typename T>
  84. void copy_from(const T& b) {
  85. dtype = b.dtype;
  86. format = b.format;
  87. should_flip = b.should_flip;
  88. group = b.group;
  89. icpg = b.icpg;
  90. ocpg = b.ocpg;
  91. spatial_ndim = b.spatial_ndim;
  92. memcpy(stride, b.stride, sizeof(stride));
  93. memcpy(padding, b.padding, sizeof(padding));
  94. memcpy(spatial, b.spatial, sizeof(spatial));
  95. memcpy(dilation, b.dilation, sizeof(dilation));
  96. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  97. }
  98. bool operator==(const CanonizedFilterMeta& b) const {
  99. bool flag = true;
  100. flag = flag && (format == b.format);
  101. flag = flag && (dtype == b.dtype);
  102. flag = flag && (should_flip == b.should_flip);
  103. flag = flag && (group == b.group);
  104. flag = flag && (icpg == b.icpg);
  105. flag = flag && (ocpg == b.ocpg);
  106. flag = flag && (spatial_ndim == b.spatial_ndim);
  107. if (flag) {
  108. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  109. flag = flag && (stride[i] == b.stride[i]);
  110. flag = flag && (padding[i] == b.padding[i]);
  111. flag = flag && (spatial[i] == b.spatial[i]);
  112. flag = flag && (dilation[i] == b.dilation[i]);
  113. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  114. }
  115. }
  116. return flag;
  117. }
  118. };
  119. struct PreprocessedFilter {
  120. //! user data; its lifetime should be bound to MegDNN Convolution
  121. //! operator
  122. void* algorithm_id;
  123. TensorNDArray tensors;
  124. };
  125. protected:
  126. // Check or deduce output DType
  127. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  128. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  129. const TensorLayout& filter,
  130. TensorLayout& dst) const;
  131. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  132. const TensorLayout& filter,
  133. const TensorLayout& dst) const;
  134. CanonizedFilterMeta make_canonized_filter_meta(
  135. size_t src_ndim, const TensorLayout& filter) const;
  136. };
  137. class MaskPropagate : public OperatorBase {
  138. DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
  139. DEF_OPR_PARAM(MaskPropagate);
  140. public:
  141. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  142. _megdnn_workspace workspace) = 0;
  143. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  144. const TensorLayout& dst) = 0;
  145. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  146. };
  147. /**
  148. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  149. */
  150. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  151. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  152. public:
  153. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  154. _megdnn_tensor_in mask, _megdnn_tensor_out dst,
  155. _megdnn_workspace worksapce) = 0;
  156. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  157. const TensorLayout& filter,
  158. const TensorLayout& mask,
  159. const TensorLayout& dst) = 0;
  160. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  161. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  162. const TensorLayout& mask, TensorLayout& dst);
  163. protected:
  164. CanonizedFilterMeta check_exec(const TensorLayout& src,
  165. const TensorLayout& filter,
  166. const TensorLayout& mask,
  167. const TensorLayout& dst,
  168. size_t workspace_in_bytes);
  169. };
  170. using MaskConvolution = MaskConvForward;
  171. /**
  172. * \brief ConvolutionForward operator.
  173. */
  174. class ConvolutionForward : public ConvolutionBase<param::Convolution>,
  175. public detail::MultiAlgoOpr<ConvolutionForward, 3> {
  176. DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
  177. public:
  178. /**
  179. * \param[in] src (n, ic, ih, iw)
  180. * \param[in] filter (oc, ic, fh, fw)
  181. * \param[out] dst (n, oc, oh, ow)
  182. */
  183. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  184. _megdnn_tensor_out dst,
  185. const PreprocessedFilter* preprocessed_filter,
  186. _megdnn_workspace workspace) = 0;
  187. virtual void exec_preprocess(const TensorLayout& src_layout,
  188. _megdnn_tensor_in filter,
  189. const TensorLayout& dst_layout,
  190. PreprocessedFilter* preprocessed_filter,
  191. _megdnn_workspace workspace) = 0;
  192. void deduce_dtype(DType src, DType filter, DType& dst);
  193. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  194. TensorLayout& dst);
  195. virtual size_t get_workspace_in_bytes(
  196. const TensorLayout& src, const TensorLayout& filter,
  197. const TensorLayout& dst,
  198. const PreprocessedFilter* preprocessed_filter) = 0;
  199. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  200. const TensorLayout& src, const TensorLayout& filter,
  201. const TensorLayout& dst) = 0;
  202. virtual size_t get_preprocess_workspace_in_bytes(
  203. const TensorLayout& src, const TensorLayout& filter,
  204. const TensorLayout& dst) = 0;
  205. protected:
  206. CanonizedFilterMeta check_exec(const TensorLayout& src,
  207. const TensorLayout& filter,
  208. const TensorLayout& dst,
  209. size_t workspace_in_bytes);
  210. };
  211. using Convolution = ConvolutionForward;
  212. /**
  213. * \brief ConvolutionBackwardData operator.
  214. *
  215. * Calculating the gradient wrt. convolution input data.
  216. */
  217. class ConvolutionBackwardData
  218. : public ConvolutionBase<param::Convolution>,
  219. public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
  220. DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
  221. public:
  222. /**
  223. * \param[in] filter (oc, ic, fh, fw)
  224. * \param[in] diff (n, oc, oh, ow)
  225. * \param[out] grad (n, ic, ih, iw)
  226. */
  227. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  228. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  229. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  230. const TensorLayout& diff,
  231. const TensorLayout& grad) = 0;
  232. void deduce_dtype(DType filter, DType diff, DType& grad);
  233. void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
  234. TensorLayout& grad);
  235. protected:
  236. CanonizedFilterMeta check_exec(const TensorLayout& filter,
  237. const TensorLayout& diff,
  238. const TensorLayout& grad,
  239. size_t workspace_in_bytes);
  240. };
  241. /**
  242. * \brief ConvolutionBackwardFilter operator.
  243. *
  244. * Calculating the gradient wrt. convolution filter.
  245. */
  246. class ConvolutionBackwardFilter
  247. : public ConvolutionBase<param::Convolution>,
  248. public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
  249. DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
  250. public:
  251. /**
  252. * \param[in] src (n, ic, ih, iw)
  253. * \param[in] diff (n, oc, oh, ow)
  254. * \param[out] grad (oc, ic, fh, fw)
  255. */
  256. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  257. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  258. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  259. const TensorLayout& diff,
  260. const TensorLayout& grad) = 0;
  261. protected:
  262. CanonizedFilterMeta check_exec(const TensorLayout& src,
  263. const TensorLayout& diff,
  264. const TensorLayout& grad,
  265. size_t workspace_in_bytes);
  266. };
  267. /**
  268. * \brief ConvolutionBias operator
  269. */
  270. class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
  271. public detail::MultiAlgoOpr<ConvBiasForward, 5> {
  272. DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
  273. public:
  274. /**
  275. * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
  276. * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
  277. * 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out]
  278. * dst (n, oc, oh, ow) or (n, oh, ow, oc)
  279. *
  280. * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
  281. * alphaw, oc, ic)
  282. */
  283. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  284. _megdnn_tensor_in bias, _megdnn_tensor_in z,
  285. _megdnn_tensor_out dst,
  286. const PreprocessedFilter* preprocessed_filter,
  287. _megdnn_workspace workspace) = 0;
  288. virtual void exec_preprocess(const TensorLayout& src_layout,
  289. _megdnn_tensor_in filter,
  290. const TensorLayout& bias_layout,
  291. const TensorLayout& z_layout,
  292. const TensorLayout& dst_layout,
  293. PreprocessedFilter* preprocessed_filter,
  294. _megdnn_workspace workspace) = 0;
  295. void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
  296. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  297. const TensorLayout& bias, const TensorLayout& z,
  298. TensorLayout& dst);
  299. virtual size_t get_workspace_in_bytes(
  300. const TensorLayout& src, const TensorLayout& filter,
  301. const TensorLayout& bias, const TensorLayout& z,
  302. const TensorLayout& dst,
  303. const PreprocessedFilter* preprocessed_filter) = 0;
  304. virtual size_t get_preprocess_workspace_in_bytes(
  305. const TensorLayout& src, const TensorLayout& filter,
  306. const TensorLayout& bias, const TensorLayout& z,
  307. const TensorLayout& dst) = 0;
  308. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  309. const TensorLayout& src, const TensorLayout& filter,
  310. const TensorLayout& bias, const TensorLayout& z,
  311. const TensorLayout& dst) = 0;
  312. enum class BiasMode : uint32_t {
  313. NO_BIAS = 0, //!< no bias
  314. BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
  315. BIAS //!< [N, C, H, W]
  316. };
  317. //! param for winograd algos.
  318. struct WinogradParam {
  319. uint32_t channel_block_size;
  320. uint32_t output_block_size;
  321. uint32_t tile_size;
  322. bool operator==(const WinogradParam& rhs) const {
  323. return channel_block_size == rhs.channel_block_size &&
  324. output_block_size == rhs.output_block_size &&
  325. tile_size == rhs.tile_size;
  326. }
  327. std::string to_string() const;
  328. };
  329. static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
  330. struct DirectParam {
  331. std::string to_string() const { return ""; }
  332. };
  333. struct MatmulParam {
  334. std::string to_string() const { return ""; }
  335. };
  336. struct DefaultParam {
  337. std::string to_string() const { return ""; }
  338. };
  339. //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
  340. //! \warning: base must not contain :.
  341. template <typename T>
  342. static std::string algo_name(const std::string& base, const T& p);
  343. /*!
  344. * \brief parse algo_name and get WinogradParam from algo name.
  345. *
  346. * \param algo name string
  347. * \return WinogradParam parsed from algo name, use pattern
  348. * winograd:base:m:tile_size.
  349. *
  350. * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
  351. */
  352. static WinogradParam parse_winograd_name(const std::string& algo_name);
  353. protected:
  354. CanonizedFilterMeta check_exec(const TensorLayout& src,
  355. const TensorLayout& filter,
  356. const TensorLayout& bias,
  357. const TensorLayout& z,
  358. const TensorLayout& dst,
  359. size_t workspace_in_bytes);
  360. };
  361. using ConvBias = ConvBiasForward;
  362. /**
  363. * \brief base class for Conv - Nonline - Pooling
  364. */
  365. class ConvPoolingBase : public OperatorBase {
  366. DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
  367. /**
  368. * \ Param::Method: Two methods to fetch the input data.
  369. * Default methods is WITH_TEXTURE_OBJ.
  370. * If you want to use WITH_SHARED_MEM mode,
  371. * please make sure that the size of
  372. * [ all of the fliter kernels + a channel
  373. * of input data + a channel of output data]
  374. * should be no large than 38KB.
  375. * And the pooling mode should not be "MAX".
  376. */
  377. DEF_OPR_PARAM(ConvPooling);
  378. protected:
  379. virtual void deduce_layout(const TensorLayout& src,
  380. const TensorLayout& filter,
  381. const TensorLayout& bias, TensorLayout& dst) = 0;
  382. virtual void check_layout(const TensorLayout& src,
  383. const TensorLayout& filter,
  384. const TensorLayout& bias, TensorLayout& dst,
  385. size_t workspace_limit_in_bytes) = 0;
  386. };
  387. class ConvPoolingForward : public ConvPoolingBase {
  388. DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);
  389. public:
  390. /**
  391. * \param[in] src input tensor
  392. * \param[out] dst output tensor
  393. */
  394. virtual void exec(const _megdnn_in TensorND src,
  395. const _megdnn_in TensorND filter,
  396. const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
  397. _megdnn_out Workspace workspace) = 0;
  398. virtual void deduce_layout(const TensorLayout& src,
  399. const TensorLayout& filter,
  400. const TensorLayout& bias, TensorLayout& dst) = 0;
  401. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  402. const TensorLayout& filter,
  403. const TensorLayout& bias,
  404. const TensorLayout& dst) = 0;
  405. protected:
  406. virtual void check_layout(const TensorLayout& src,
  407. const TensorLayout& filter,
  408. const TensorLayout& bias, TensorLayout& dst,
  409. size_t workspace_limit_in_bytes) = 0;
  410. };
  411. using ConvPooling = ConvPoolingForward;
  412. class GroupLocalBase : public OperatorBase {
  413. DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
  414. DEF_OPR_PARAM(Convolution);
  415. public:
  416. using Mode = Param::Mode;
  417. protected:
  418. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  419. TensorLayout& dst);
  420. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  421. const TensorLayout& dst);
  422. };
  423. class GroupLocalForward : public GroupLocalBase {
  424. DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);
  425. public:
  426. /**
  427. * \param[in] src (N, IC, IH, IW)
  428. * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
  429. * \param[out] dst (N, OC, OH, OW)
  430. **/
  431. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  432. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  433. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  434. TensorLayout& dst) {
  435. deduce_layout_fwd(src, filter, dst);
  436. }
  437. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  438. const TensorLayout& filter,
  439. const TensorLayout& dst) = 0;
  440. protected:
  441. void check_exec(const TensorLayout& src, const TensorLayout& filter,
  442. const TensorLayout& dst, size_t workspace_in_bytes);
  443. };
  444. using GroupLocal = GroupLocalForward;
  445. class GroupLocalBackwardData : public GroupLocalBase {
  446. DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);
  447. public:
  448. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  449. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  450. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  451. const TensorLayout& diff,
  452. const TensorLayout& grad) = 0;
  453. protected:
  454. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  455. const TensorLayout& grad, size_t workspace_in_bytes);
  456. };
  457. class GroupLocalBackwardFilter : public GroupLocalBase {
  458. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  459. public:
  460. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  461. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  462. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  463. const TensorLayout& diff,
  464. const TensorLayout& grad) = 0;
  465. protected:
  466. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  467. const TensorLayout& grad, size_t workspace_in_bytes);
  468. };
  469. class Images2NeibsBase : public OperatorBase {
  470. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  471. DEF_OPR_PARAM(Images2Neibs);
  472. protected:
  473. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  474. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  475. };
  476. class Images2NeibsForward : public Images2NeibsBase {
  477. DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);
  478. public:
  479. /**
  480. * \param[in] src (N, C, IH, IW)
  481. * \param[out] dst (N, C, OH, OW, window_h, window_w)
  482. *
  483. * \see
  484. * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
  485. *
  486. * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
  487. * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$.
  488. */
  489. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  490. _megdnn_workspace workspace) = 0;
  491. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  492. const TensorLayout& dst) = 0;
  493. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  494. protected:
  495. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  496. size_t workspace_in_bytes);
  497. };
  498. using Images2Neibs = Images2NeibsForward;
  499. class Images2NeibsBackward : public Images2NeibsBase {
  500. DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);
  501. public:
  502. /**
  503. * \param[in] diff the backpropagated gradient wrt. dst
  504. * \param[out] grad the backpropagated gradient wrt. src
  505. */
  506. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  507. _megdnn_workspace workspace) = 0;
  508. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  509. const TensorLayout& grad) = 0;
  510. protected:
  511. void check_exec(const TensorLayout& diff, const TensorLayout& grad,
  512. size_t workspace_in_bytes);
  513. };
  514. /**
  515. * \brief base class for Pooling
  516. */
  517. class PoolingBase : public OperatorBase {
  518. DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
  519. DEF_OPR_PARAM(Pooling);
  520. public:
  521. using Mode = Param::Mode;
  522. protected:
  523. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  524. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  525. };
  526. class PoolingForward : public PoolingBase {
  527. DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);
  528. public:
  529. /**
  530. * \param[in] src input tensor
  531. * \param[out] dst output tensor
  532. */
  533. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  534. _megdnn_workspace workspace) = 0;
  535. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  536. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  537. const TensorLayout& dst) = 0;
  538. protected:
  539. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  540. size_t workspace_in_bytes);
  541. };
  542. using Pooling = PoolingForward;
  543. class PoolingBackward : public PoolingBase {
  544. DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);
  545. public:
  546. /**
  547. * \param[in] src the `src' parameter in PoolingForward::exec
  548. * \param[in] dst the `dst' parameter in PoolingForward::exec
  549. * \param[in] diff the backpropagated gradient wrt. dst
  550. * \param[out] grad the backpropagated gradient wrt. src
  551. */
  552. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
  553. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  554. _megdnn_workspace workspace) = 0;
  555. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  556. const TensorLayout& dst,
  557. const TensorLayout& diff,
  558. const TensorLayout& grad) = 0;
  559. protected:
  560. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  561. const TensorLayout& diff, const TensorLayout& grad,
  562. size_t workspace_in_bytes);
  563. };
  564. /**
  565. * \brief base class for Local
  566. */
  567. class LocalBase : public OperatorBase {
  568. DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
  569. DEF_OPR_PARAM(Convolution);
  570. public:
  571. using Mode = Param::Mode;
  572. protected:
  573. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  574. TensorLayout& dst);
  575. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  576. const TensorLayout& dst);
  577. };
  578. class LocalForward : public LocalBase {
  579. DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);
  580. public:
  581. /**
  582. * \param[in] src (n, ic, ih, iw)
  583. * \param[in] filter (oh, ow, ic, fh, fw, oc)
  584. * \param[out] dst (n, oc, oh, ow)
  585. */
  586. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  587. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  588. /**
  589. * \brief Deducing output tensor layouts from input tensor layouts.
  590. *
  591. * Be aware that the first and second dimension of `filter' are ignored
  592. * when deducing `dst' layout.
  593. */
  594. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  595. TensorLayout& dst);
  596. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  597. const TensorLayout& filter,
  598. const TensorLayout& dst) = 0;
  599. protected:
  600. void check_exec(const TensorLayout& src, const TensorLayout& filter,
  601. const TensorLayout& dst, size_t workspace_in_bytes);
  602. };
  603. using Local = LocalForward;
  604. class LocalBackwardData : public LocalBase {
  605. DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);
  606. public:
  607. /**
  608. * \param[in] filter (oh, ow, ic, fh, fw, oc)
  609. * \param[in] diff (n, oc, oh, ow)
  610. * \param[out] grad (n, ic, ih, iw)
  611. */
  612. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  613. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  614. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  615. const TensorLayout& diff,
  616. const TensorLayout& grad) = 0;
  617. protected:
  618. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  619. const TensorLayout& grad, size_t workspace_in_bytes);
  620. };
  621. class LocalBackwardFilter : public LocalBase {
  622. DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);
  623. public:
  624. /**
  625. * \param[in] src (n, ic, ih, iw)
  626. * \param[in] diff (n, oc, oh, ow)
  627. * \param[out] grad (oh, ow, ic, fh, fw, oc)
  628. */
  629. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  630. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  631. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  632. const TensorLayout& diff,
  633. const TensorLayout& grad) = 0;
  634. protected:
  635. void check_exec(const TensorLayout& src, const TensorLayout& diff,
  636. const TensorLayout& grad, size_t workspace_in_bytes);
  637. };
  638. class BNBase : public OperatorBase {
  639. DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
  640. DEF_OPR_PARAM(BN);
  641. protected:
  642. void check_param();
  643. };
/**
 * \brief Batch normalization forward.
 *
 *     dst[i] = gamma * (x[i] - estimatedMean[k])
 *                    / sqrt(epsilon + estimatedVariance[k]) + beta
 *
 * where epsilon is a very small value to avoid a "divide by zero" error.
 */
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 5);

public:
    /**
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[in,out] mean (see m_param.ParamDim) Global (running) mean.
     * \param[in,out] variance (see m_param.ParamDim) Global (running) variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *     Optionally cached intermediate mean from the forward pass.
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *     Optionally cached intermediate inverse variance from the forward
     *     pass.
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                      _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
                      _megdnn_tensor_inout variance,
                      _megdnn_tensor_out batch_mean,
                      _megdnn_tensor_out batch_inv_variance,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layouts of all outputs from the src layout
    void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
                       TensorLayout& bn_bias, TensorLayout& mean,
                       TensorLayout& variance, TensorLayout& batch_mean,
                       TensorLayout& batch_inv_variance, TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
                    const TensorLayout& bn_bias, const TensorLayout& mean,
                    const TensorLayout& variance,
                    const TensorLayout& batch_mean,
                    const TensorLayout& batch_inv_variance,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
//! convenience alias: the default BN operator is the forward pass
using BN = BNForward;
/**
 * \brief Batch normalization backward: gradients wrt. input, scale and bias.
 */
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);

public:
    /**
     * \param[in] x input data of the forward pass.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
     * Optionally cached intermediate results from the forward pass:
     * \param[in] saved_batch_mean mean of the input batch,
     *     calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch,
     *     calculated in the forward propagation.
     */
    virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                      _megdnn_tensor_in saved_batch_mean,
                      _megdnn_tensor_in saved_batch_variance,
                      _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
                      _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance,
            const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& x, const TensorLayout& dy,
                    const TensorLayout& saved_batch_mean,
                    const TensorLayout& saved_batch_variance,
                    const TensorLayout& bn_scale,
                    const TensorLayout& d_bn_scale,
                    const TensorLayout& d_bn_bias, const TensorLayout& dx,
                    size_t workspace_in_bytes);
};
/**
 * \brief Base class of local response normalization operators; holds the LRN
 * param.
 */
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! validate the operator param
    void check_param();
};
/**
 * \brief Local response normalization forward.
 */
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce the dst layout from the src layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
//! convenience alias: the default LRN operator is the forward pass
using LRN = LRNForward;
/**
 * \brief Local response normalization backward.
 */
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief Base class of ROI pooling operators; holds the ROIPooling param.
 */
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate the layouts of the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
/**
 * \brief ROI pooling forward.
 */
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn .
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
//! convenience alias: the default ROIPooling operator is the forward pass
using ROIPooling = ROIPoolingForward;
/**
 * \brief ROI pooling backward.
 */
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
                      _megdnn_tensor_in rois, _megdnn_tensor_in index,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& diff, const TensorLayout& src,
                    const TensorLayout& rois, const TensorLayout& index,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Base class of 3D convolution operators; holds the Convolution3D
 * param and the canonized filter metadata shared by forward/backward oprs.
 */
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;
    //! canonical (group-aware) description of a 3D convolution filter
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! deduce the dst layout and return the canonized filter metadata
    CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          TensorLayout& dst) const;
    //! validate the forward layouts and return the canonized filter metadata
    CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
/**
 * \brief 3D convolution forward.
 */
class Convolution3DForward
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the dst layout from src and filter layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! convenience alias: the default Convolution3D operator is the forward pass
using Convolution3D = Convolution3DForward;
/**
 * \brief 3D convolution backward: gradient wrt. the input data.
 */
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow) backpropagated gradient wrt. dst
     * \param[out] grad (n, ic, id, ih, iw) gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce the grad layout from filter and diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief 3D convolution backward: gradient wrt. the filter.
 */
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow) backpropagated gradient wrt. dst
     * \param[out] grad (oc, ic, fd, fh, fw) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief Base class of local-share convolution operators; holds the
 * LocalShare param.
 */
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! deduce the dst layout of the forward pass
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate the layouts of the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
/**
 * \brief Local-share convolution forward.
 */
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
//! convenience alias: the default LocalShare operator is the forward pass
using LocalShare = LocalShareForward;
/**
 * \brief Local-share convolution backward: gradient wrt. the input data.
 */
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW) backpropagated gradient wrt. dst
     * \param[out] grad (N, IC, IH, IW) gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce the grad layout from filter and diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Local-share convolution backward: gradient wrt. the filter.
 */
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW) backpropagated gradient wrt. dst
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief Base class of ROI align operators; holds the ROIAlign param.
 */
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! deduce the dst/index layouts of the forward pass
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    //! validate the layouts of the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
/**
 * \brief ROI align forward.
 */
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    //! deduce dst and index layouts from src and rois layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
//! convenience alias: the default ROIAlign operator is the forward pass
using ROIAlign = ROIAlignForward;
/**
 * \brief ROI align backward.
 */
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief Base class of deformable convolution operators; extends the
 * canonized filter metadata with the number of deformable groups.
 */
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;  //!< number of deformable groups (dg)
    };

protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    //! deduce the dst layout of the forward pass
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    //! validate the layouts of the forward pass
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
/**
 * \brief Deformable convolution forward.
 */
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the dst layout from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! convenience alias: the default DeformableConv operator is the forward pass
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw) the `im' parameter in
     *     DeformableConvForward::exec
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow) backpropagated gradient wrt. dst
     * \param[out] filter_grad (oc, ic, fh, fw) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    //! deduce the filter_grad layout from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw) the `im' parameter in
     *     DeformableConvForward::exec
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow) backpropagated gradient wrt. dst
     * \param[out] im_grad (n, ic, ih, iw) gradient wrt. im
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow) gradient wrt. offset
     * \param[out] mask_grad (dg, fh, fw, oh, ow) gradient wrt. mask
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    //! deduce the gradient layouts from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
/**
 * \brief Base class of deformable position-sensitive ROI pooling operators;
 * holds the DeformablePSROIPooling param.
 */
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! deduce the output layouts of the forward pass
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    //! validate the layouts of the forward pass
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
/**
 * \brief Deformable position-sensitive ROI pooling forward.
 */
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * NOTE(review): the original shape annotations below look like
     * placeholders; verify them against the implementation.
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data ( n, ic, ih, iw)
     * \param[out] out_count ( n, ic, ih, iw)
     */
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    //! deduce out_data and out_count layouts from the input layouts
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
//! convenience alias: the default operator is the forward pass
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
/**
 * \brief Deformable position-sensitive ROI pooling backward.
 */
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * NOTE(review): the original shape annotations below look like
     * placeholders; verify them against the implementation.
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx) backpropagated gradient wrt.
     *     out_data
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff ( n, ic, ih, iw) gradient wrt. data
     * \param[out] trans_diff ( n, ic, ih, iw) gradient wrt. trans
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    //! validate layouts and workspace size before execution
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
/**
 * \brief Batched convolution with bias and optional residual addition.
 *
 * Takes src, filter, bias and z (residual input) and produces dst; see
 * deduce_layout for how the output layout follows from the inputs.
 */
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src convolution input
     * \param[in] filter convolution filter
     * \param[in] bias bias tensor
     * \param[in] z residual input added to the result
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the dst dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! deduce the dst layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before execution
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! convenience alias: the default BatchConvBias operator is the forward pass
using BatchConvBias = BatchConvBiasForward;
  1299. } // namespace megdnn
  1300. #include "megdnn/internal/opr_header_epilogue.h"
  1301. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台