You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

nn.h 59 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443
/**
 * \file dnn/include/megdnn/oprs/nn.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. class SeparableConvBase : public OperatorBase {
  15. DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
  16. DEF_OPR_PARAM(SeparableConv);
  17. public:
  18. using Mode = Param::Mode;
  19. protected:
  20. void deduce_layout_fwd(const TensorLayout& src,
  21. const TensorLayout& filter_x,
  22. const TensorLayout& filter_y, TensorLayout& dst);
  23. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
  24. const TensorLayout& filter_y,
  25. const TensorLayout& dst);
  26. };
  27. class SeparableConvForward : public SeparableConvBase {
  28. DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
  29. public:
  30. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  31. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  32. _megdnn_workspace workspace) = 0;
  33. void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
  34. const TensorLayout& filter_y, TensorLayout& dst);
  35. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  36. const TensorLayout& filter_x,
  37. const TensorLayout& filter_y,
  38. const TensorLayout& dst) = 0;
  39. protected:
  40. void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
  41. const TensorLayout& filter_y, const TensorLayout& dst,
  42. size_t workspace_in_bytes);
  43. };
  44. using SeparableConv = SeparableConvForward;
  45. /**
  46. * \brief base class for convolution operation
  47. *
  48. * This operator is supposed to perform convolution on arbitrary input
  49. * dimensions. The input/output format is N, C, dims..., and kernel format can
  50. * take two forms:
  51. * 1. OC, IC, dims..., for conventional dense convolution
  52. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  53. *
  54. * Currently, only 2D images are supported.
  55. */
  56. template <typename Parameter>
  57. class ConvolutionBase : public OperatorBase {
  58. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  59. using Param = Parameter;
  60. public:
  61. Param& param() { return m_param; }
  62. const Param& param() const { return m_param; }
  63. protected:
  64. Param m_param;
  65. public:
  66. static constexpr size_t MAX_SPATIAL_DIM = 2;
  67. using Mode = typename Param::Mode;
  68. struct CanonizedFilterMeta {
  69. DType dtype;
  70. typename Param::Format format;
  71. uint32_t
  72. //! whether filter should be flipped (i.e. is CONVOLUTION)
  73. should_flip,
  74. group, //!< number of groups
  75. icpg, //!< input channels per group
  76. ocpg, //!< output channels per group
  77. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  78. //! spatial dim
  79. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  80. //! spatial dim with dilation applied
  81. dilated_spatial[MAX_SPATIAL_DIM];
  82. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  83. template <typename T>
  84. void copy_from(const T& b) {
  85. dtype = b.dtype;
  86. format = b.format;
  87. should_flip = b.should_flip;
  88. group = b.group;
  89. icpg = b.icpg;
  90. ocpg = b.ocpg;
  91. spatial_ndim = b.spatial_ndim;
  92. memcpy(stride, b.stride, sizeof(stride));
  93. memcpy(padding, b.padding, sizeof(padding));
  94. memcpy(spatial, b.spatial, sizeof(spatial));
  95. memcpy(dilation, b.dilation, sizeof(dilation));
  96. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  97. }
  98. bool operator==(const CanonizedFilterMeta& b) const {
  99. bool flag = true;
  100. flag = flag && (format == b.format);
  101. flag = flag && (dtype == b.dtype);
  102. flag = flag && (should_flip == b.should_flip);
  103. flag = flag && (group == b.group);
  104. flag = flag && (icpg == b.icpg);
  105. flag = flag && (ocpg == b.ocpg);
  106. flag = flag && (spatial_ndim == b.spatial_ndim);
  107. if (flag) {
  108. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  109. flag = flag && (stride[i] == b.stride[i]);
  110. flag = flag && (padding[i] == b.padding[i]);
  111. flag = flag && (spatial[i] == b.spatial[i]);
  112. flag = flag && (dilation[i] == b.dilation[i]);
  113. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  114. }
  115. }
  116. return flag;
  117. }
  118. };
  119. protected:
  120. // Check or deduce output DType
  121. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  122. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  123. const TensorLayout& filter,
  124. TensorLayout& dst) const;
  125. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  126. const TensorLayout& filter,
  127. const TensorLayout& dst) const;
  128. CanonizedFilterMeta make_canonized_filter_meta(
  129. size_t src_ndim, const TensorLayout& filter) const;
  130. };
  131. class MaskPropagate : public OperatorBase {
  132. DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
  133. DEF_OPR_PARAM(MaskPropagate);
  134. public:
  135. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  136. _megdnn_workspace workspace) = 0;
  137. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  138. const TensorLayout& dst) = 0;
  139. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  140. };
  141. /**
  142. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  143. */
  144. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  145. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  146. public:
  147. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  148. _megdnn_tensor_in mask, _megdnn_tensor_out dst,
  149. _megdnn_workspace worksapce) = 0;
  150. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  151. const TensorLayout& filter,
  152. const TensorLayout& mask,
  153. const TensorLayout& dst) = 0;
  154. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  155. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  156. const TensorLayout& mask, TensorLayout& dst);
  157. protected:
  158. CanonizedFilterMeta check_exec(const TensorLayout& src,
  159. const TensorLayout& filter,
  160. const TensorLayout& mask,
  161. const TensorLayout& dst,
  162. size_t workspace_in_bytes);
  163. };
  164. using MaskConvolution = MaskConvForward;
  165. /**
  166. * \brief ConvolutionForward operator.
  167. */
  168. class ConvolutionForward : public ConvolutionBase<param::Convolution>,
  169. public detail::MultiAlgoOpr<ConvolutionForward, 3> {
  170. DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
  171. public:
  172. /**
  173. * \param[in] src (n, ic, ih, iw)
  174. * \param[in] filter (oc, ic, fh, fw)
  175. * \param[out] dst (n, oc, oh, ow)
  176. */
  177. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  178. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  179. void deduce_dtype(DType src, DType filter, DType& dst);
  180. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  181. TensorLayout& dst);
  182. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  183. const TensorLayout& filter,
  184. const TensorLayout& dst) = 0;
  185. protected:
  186. CanonizedFilterMeta check_exec(const TensorLayout& src,
  187. const TensorLayout& filter,
  188. const TensorLayout& dst,
  189. size_t workspace_in_bytes);
  190. };
  191. using Convolution = ConvolutionForward;
  192. /**
  193. * \brief ConvolutionBackwardData operator.
  194. *
  195. * Calculating the gradient wrt. convolution input data.
  196. */
  197. class ConvolutionBackwardData
  198. : public ConvolutionBase<param::Convolution>,
  199. public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
  200. DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
  201. public:
  202. /**
  203. * \param[in] filter (oc, ic, fh, fw)
  204. * \param[in] diff (n, oc, oh, ow)
  205. * \param[out] grad (n, ic, ih, iw)
  206. */
  207. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  208. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  209. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  210. const TensorLayout& diff,
  211. const TensorLayout& grad) = 0;
  212. void deduce_dtype(DType filter, DType diff, DType& grad);
  213. void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
  214. TensorLayout& grad);
  215. protected:
  216. CanonizedFilterMeta check_exec(const TensorLayout& filter,
  217. const TensorLayout& diff,
  218. const TensorLayout& grad,
  219. size_t workspace_in_bytes);
  220. };
  221. /**
  222. * \brief ConvolutionBackwardFilter operator.
  223. *
  224. * Calculating the gradient wrt. convolution filter.
  225. */
  226. class ConvolutionBackwardFilter
  227. : public ConvolutionBase<param::Convolution>,
  228. public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
  229. DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
  230. public:
  231. /**
  232. * \param[in] src (n, ic, ih, iw)
  233. * \param[in] diff (n, oc, oh, ow)
  234. * \param[out] grad (oc, ic, fh, fw)
  235. */
  236. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  237. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  238. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  239. const TensorLayout& diff,
  240. const TensorLayout& grad) = 0;
  241. protected:
  242. CanonizedFilterMeta check_exec(const TensorLayout& src,
  243. const TensorLayout& diff,
  244. const TensorLayout& grad,
  245. size_t workspace_in_bytes);
  246. };
  247. /**
  248. * \brief ConvolutionBias operator
  249. */
  250. class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
  251. public detail::MultiAlgoOpr<ConvBiasForward, 5> {
  252. DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
  253. public:
  254. /**
  255. * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
  256. * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
  257. * 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out]
  258. * dst (n, oc, oh, ow) or (n, oh, ow, oc)
  259. *
  260. * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
  261. * alphaw, oc, ic)
  262. */
  263. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  264. _megdnn_tensor_in bias, _megdnn_tensor_in z,
  265. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  266. void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
  267. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  268. const TensorLayout& bias, const TensorLayout& z,
  269. TensorLayout& dst);
  270. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  271. const TensorLayout& filter,
  272. const TensorLayout& bias,
  273. const TensorLayout& z,
  274. const TensorLayout& dst) = 0;
  275. enum class BiasMode : uint32_t {
  276. NO_BIAS = 0, //!< no bias
  277. BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
  278. BIAS //!< [N, C, H, W]
  279. };
  280. //! param for winograd algos.
  281. struct WinogradParam {
  282. uint32_t channel_block_size;
  283. uint32_t output_block_size;
  284. uint32_t tile_size;
  285. bool operator==(const WinogradParam& rhs) const {
  286. return channel_block_size == rhs.channel_block_size &&
  287. output_block_size == rhs.output_block_size &&
  288. tile_size == rhs.tile_size;
  289. }
  290. std::string to_string() const;
  291. };
  292. static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
  293. struct DirectParam {
  294. std::string to_string() const { return ""; }
  295. };
  296. struct MatmulParam {
  297. std::string to_string() const { return ""; }
  298. };
  299. struct DefaultParam {
  300. std::string to_string() const { return ""; }
  301. };
  302. //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
  303. //! \warning: base must not contain :.
  304. template <typename T>
  305. static std::string algo_name(const std::string& base, const T& p);
  306. /*!
  307. * \brief parse algo_name and get WinogradParam from algo name.
  308. *
  309. * \param algo name string
  310. * \return WinogradParam parsed from algo name, use pattern
  311. * winograd:base:m:tile_size.
  312. *
  313. * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
  314. */
  315. static WinogradParam parse_winograd_name(const std::string& algo_name);
  316. protected:
  317. CanonizedFilterMeta check_exec(const TensorLayout& src,
  318. const TensorLayout& filter,
  319. const TensorLayout& bias,
  320. const TensorLayout& z,
  321. const TensorLayout& dst,
  322. size_t workspace_in_bytes);
  323. };
  324. using ConvBias = ConvBiasForward;
  325. /**
  326. * \brief base class for Conv - Nonline - Pooling
  327. */
  328. class ConvPoolingBase : public OperatorBase {
  329. DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
  330. /**
  331. * \ Param::Method: Two methods to fetch the input data.
  332. * Default methods is WITH_TEXTURE_OBJ.
  333. * If you want to use WITH_SHARED_MEM mode,
  334. * please make sure that the size of
  335. * [ all of the fliter kernels + a channel
  336. * of input data + a channel of output data]
  337. * should be no large than 38KB.
  338. * And the pooling mode should not be "MAX".
  339. */
  340. DEF_OPR_PARAM(ConvPooling);
  341. protected:
  342. virtual void deduce_layout(const TensorLayout& src,
  343. const TensorLayout& filter,
  344. const TensorLayout& bias, TensorLayout& dst) = 0;
  345. virtual void check_layout(const TensorLayout& src,
  346. const TensorLayout& filter,
  347. const TensorLayout& bias, TensorLayout& dst,
  348. size_t workspace_limit_in_bytes) = 0;
  349. };
  350. class ConvPoolingForward : public ConvPoolingBase {
  351. DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);
  352. public:
  353. /**
  354. * \param[in] src input tensor
  355. * \param[out] dst output tensor
  356. */
  357. virtual void exec(const _megdnn_in TensorND src,
  358. const _megdnn_in TensorND filter,
  359. const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
  360. _megdnn_out Workspace workspace) = 0;
  361. virtual void deduce_layout(const TensorLayout& src,
  362. const TensorLayout& filter,
  363. const TensorLayout& bias, TensorLayout& dst) = 0;
  364. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  365. const TensorLayout& filter,
  366. const TensorLayout& bias,
  367. const TensorLayout& dst) = 0;
  368. protected:
  369. virtual void check_layout(const TensorLayout& src,
  370. const TensorLayout& filter,
  371. const TensorLayout& bias, TensorLayout& dst,
  372. size_t workspace_limit_in_bytes) = 0;
  373. };
  374. using ConvPooling = ConvPoolingForward;
  375. class GroupLocalBase : public OperatorBase {
  376. DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
  377. DEF_OPR_PARAM(Convolution);
  378. public:
  379. using Mode = Param::Mode;
  380. protected:
  381. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  382. TensorLayout& dst);
  383. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  384. const TensorLayout& dst);
  385. };
  386. class GroupLocalForward : public GroupLocalBase {
  387. DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);
  388. public:
  389. /**
  390. * \param[in] src (N, IC, IH, IW)
  391. * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
  392. * \param[out] dst (N, OC, OH, OW)
  393. **/
  394. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  395. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  396. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  397. TensorLayout& dst) {
  398. deduce_layout_fwd(src, filter, dst);
  399. }
  400. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  401. const TensorLayout& filter,
  402. const TensorLayout& dst) = 0;
  403. protected:
  404. void check_exec(const TensorLayout& src, const TensorLayout& filter,
  405. const TensorLayout& dst, size_t workspace_in_bytes);
  406. };
  407. using GroupLocal = GroupLocalForward;
  408. class GroupLocalBackwardData : public GroupLocalBase {
  409. DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);
  410. public:
  411. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  412. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  413. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  414. const TensorLayout& diff,
  415. const TensorLayout& grad) = 0;
  416. protected:
  417. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  418. const TensorLayout& grad, size_t workspace_in_bytes);
  419. };
  420. class GroupLocalBackwardFilter : public GroupLocalBase {
  421. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  422. public:
  423. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  424. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  425. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  426. const TensorLayout& diff,
  427. const TensorLayout& grad) = 0;
  428. protected:
  429. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  430. const TensorLayout& grad, size_t workspace_in_bytes);
  431. };
  432. class Images2NeibsBase : public OperatorBase {
  433. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  434. DEF_OPR_PARAM(Images2Neibs);
  435. protected:
  436. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  437. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  438. };
  439. class Images2NeibsForward : public Images2NeibsBase {
  440. DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);
  441. public:
  442. /**
  443. * \param[in] src (N, C, IH, IW)
  444. * \param[out] dst (N, C, OH, OW, window_h, window_w)
  445. *
  446. * \see
  447. * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
  448. *
  449. * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
  450. * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$.
  451. */
  452. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  453. _megdnn_workspace workspace) = 0;
  454. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  455. const TensorLayout& dst) = 0;
  456. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  457. protected:
  458. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  459. size_t workspace_in_bytes);
  460. };
  461. using Images2Neibs = Images2NeibsForward;
  462. class Images2NeibsBackward : public Images2NeibsBase {
  463. DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);
  464. public:
  465. /**
  466. * \param[in] diff the backpropagated gradient wrt. dst
  467. * \param[out] grad the backpropagated gradient wrt. src
  468. */
  469. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  470. _megdnn_workspace workspace) = 0;
  471. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  472. const TensorLayout& grad) = 0;
  473. protected:
  474. void check_exec(const TensorLayout& diff, const TensorLayout& grad,
  475. size_t workspace_in_bytes);
  476. };
  477. /**
  478. * \brief base class for Pooling
  479. */
  480. class PoolingBase : public OperatorBase {
  481. DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
  482. DEF_OPR_PARAM(Pooling);
  483. public:
  484. using Mode = Param::Mode;
  485. protected:
  486. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  487. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  488. };
  489. class PoolingForward : public PoolingBase {
  490. DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);
  491. public:
  492. /**
  493. * \param[in] src input tensor
  494. * \param[out] dst output tensor
  495. */
  496. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  497. _megdnn_workspace workspace) = 0;
  498. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  499. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  500. const TensorLayout& dst) = 0;
  501. protected:
  502. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  503. size_t workspace_in_bytes);
  504. };
  505. using Pooling = PoolingForward;
  506. class PoolingBackward : public PoolingBase {
  507. DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);
  508. public:
  509. /**
  510. * \param[in] src the `src' parameter in PoolingForward::exec
  511. * \param[in] dst the `dst' parameter in PoolingForward::exec
  512. * \param[in] diff the backpropagated gradient wrt. dst
  513. * \param[out] grad the backpropagated gradient wrt. src
  514. */
  515. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
  516. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  517. _megdnn_workspace workspace) = 0;
  518. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  519. const TensorLayout& dst,
  520. const TensorLayout& diff,
  521. const TensorLayout& grad) = 0;
  522. protected:
  523. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  524. const TensorLayout& diff, const TensorLayout& grad,
  525. size_t workspace_in_bytes);
  526. };
  527. /**
  528. * \brief base class for Local
  529. */
  530. class LocalBase : public OperatorBase {
  531. DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
  532. DEF_OPR_PARAM(Convolution);
  533. public:
  534. using Mode = Param::Mode;
  535. protected:
  536. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  537. TensorLayout& dst);
  538. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  539. const TensorLayout& dst);
  540. };
  541. class LocalForward : public LocalBase {
  542. DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);
  543. public:
  544. /**
  545. * \param[in] src (n, ic, ih, iw)
  546. * \param[in] filter (oh, ow, ic, fh, fw, oc)
  547. * \param[out] dst (n, oc, oh, ow)
  548. */
  549. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  550. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  551. /**
  552. * \brief Deducing output tensor layouts from input tensor layouts.
  553. *
  554. * Be aware that the first and second dimension of `filter' are ignored
  555. * when deducing `dst' layout.
  556. */
  557. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  558. TensorLayout& dst);
  559. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  560. const TensorLayout& filter,
  561. const TensorLayout& dst) = 0;
  562. protected:
  563. void check_exec(const TensorLayout& src, const TensorLayout& filter,
  564. const TensorLayout& dst, size_t workspace_in_bytes);
  565. };
  566. using Local = LocalForward;
  567. class LocalBackwardData : public LocalBase {
  568. DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);
  569. public:
  570. /**
  571. * \param[in] filter (oh, ow, ic, fh, fw, oc)
  572. * \param[in] diff (n, oc, oh, ow)
  573. * \param[out] grad (n, ic, ih, iw)
  574. */
  575. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  576. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  577. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  578. const TensorLayout& diff,
  579. const TensorLayout& grad) = 0;
  580. protected:
  581. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  582. const TensorLayout& grad, size_t workspace_in_bytes);
  583. };
  584. class LocalBackwardFilter : public LocalBase {
  585. DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);
  586. public:
  587. /**
  588. * \param[in] src (n, ic, ih, iw)
  589. * \param[in] diff (n, oc, oh, ow)
  590. * \param[out] grad (oh, ow, ic, fh, fw, oc)
  591. */
  592. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  593. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  594. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  595. const TensorLayout& diff,
  596. const TensorLayout& grad) = 0;
  597. protected:
  598. void check_exec(const TensorLayout& src, const TensorLayout& diff,
  599. const TensorLayout& grad, size_t workspace_in_bytes);
  600. };
  601. class BNBase : public OperatorBase {
  602. DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
  603. DEF_OPR_PARAM(BN);
  604. protected:
  605. void check_param();
  606. };
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 5);

public:
    /**
     * \brief Batch normalization forward pass.
     *
     * dst[i] = gamma
     *     * (x[i] - estimatedMean[k]) / sqrt(epsilon + estimatedVariance[k])
     *     + beta, where epsilon is a very small value to avoid a
     *     "divide by zero" error.
     *
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[in,out] mean (see m_param.ParamDim) Global mean; passed as
     *     _megdnn_tensor_inout, i.e. read and updated in place.
     * \param[in,out] variance (see m_param.ParamDim) Global variance; passed
     *     as _megdnn_tensor_inout, i.e. read and updated in place.
     * \param[out] batch_mean (see m_param.ParamDim)
     *     Optionally cached intermediate mean from the forward pass.
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *     Optionally cached intermediate variance from the forward pass
     *     (the name suggests it is stored as the inverse variance — confirm
     *     against the implementation).
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                      _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
                      _megdnn_tensor_inout variance,
                      _megdnn_tensor_out batch_mean,
                      _megdnn_tensor_out batch_inv_variance,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! Deduce the layouts of all outputs from the layout of \p src.
    void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
                       TensorLayout& bn_bias, TensorLayout& mean,
                       TensorLayout& variance, TensorLayout& batch_mean,
                       TensorLayout& batch_inv_variance, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance,
            const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
                    const TensorLayout& bn_bias, const TensorLayout& mean,
                    const TensorLayout& variance,
                    const TensorLayout& batch_mean,
                    const TensorLayout& batch_inv_variance,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using BN = BNForward;
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);

public:
    /**
     * \brief Batch normalization backward pass.
     *
     * \param[in] x input data of the forward propagation.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
     *
     * Optionally cached intermediate results from the forward pass:
     * \param[in] saved_batch_mean mean of the input batch.
     *     Calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch.
     *     Calculated in the forward propagation.
     */
    virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                      _megdnn_tensor_in saved_batch_mean,
                      _megdnn_tensor_in saved_batch_variance,
                      _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
                      _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance,
            const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& x, const TensorLayout& dy,
                    const TensorLayout& saved_batch_mean,
                    const TensorLayout& saved_batch_variance,
                    const TensorLayout& bn_scale,
                    const TensorLayout& d_bn_scale,
                    const TensorLayout& d_bn_bias, const TensorLayout& dx,
                    size_t workspace_in_bytes);
};
//! Common base of the local-response-normalization operators; holds the
//! shared LRN parameter.
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! Validate the operator parameter.
    void check_param();
};
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \brief Local response normalization forward pass.
     *
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using LRN = LRNForward;
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \brief Local response normalization backward pass.
     *
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
//! Common base of the ROI-pooling operators; holds the shared parameter and
//! forward layout check.
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! Validate the forward-pass layouts.
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \brief Region-of-interest pooling forward pass.
     *
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn .
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \brief Region-of-interest pooling backward pass.
     *
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
                      _megdnn_tensor_in rois, _megdnn_tensor_in index,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& diff, const TensorLayout& src,
                    const TensorLayout& rois, const TensorLayout& index,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! Common base of the 3D convolution operators; holds the shared parameter
//! and canonized filter metadata.
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;

    //! Canonical description of a 3D convolution filter, shared by the
    //! forward and backward operators.
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! Deduce the dst layout from src and filter; returns the canonized
    //! filter metadata.
    CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          TensorLayout& dst) const;
    //! Validate forward-pass layouts; returns the canonized filter metadata.
    CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
class Convolution3DForward
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \brief 3D convolution forward pass.
     *
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \brief 3D convolution gradient wrt. the input data.
     *
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \brief 3D convolution gradient wrt. the filter.
     *
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
//! Common base of the local-share convolution operators; holds the shared
//! parameter and forward layout helpers.
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! Deduce the dst layout from src and filter.
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! Validate the forward-pass layouts.
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \brief Local-share convolution forward pass.
     *
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \brief Local-share convolution gradient wrt. the input data.
     *
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \brief Local-share convolution gradient wrt. the filter.
     *
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     *     FH, FW, OC / G)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! Common base of the ROI-align operators; holds the shared parameter and
//! forward layout helpers.
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! Deduce dst and index layouts from src and rois.
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    //! Validate the forward-pass layouts.
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \brief Region-of-interest align forward pass.
     *
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \brief Region-of-interest align backward pass.
     *
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
//! Common base of the deformable convolution operators.
//! NOTE: reuses the plain Convolution parameter (DEF_OPR_PARAM below);
//! deformable-specific state is carried via CanonizedFilterMeta.
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;

    //! Convolution filter metadata extended with the deformable group count.
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };

protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    //! Deduce the dst layout from the forward-pass inputs.
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    //! Validate the forward-pass layouts.
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \brief Deformable convolution forward pass.
     *
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * NOTE(review): the shapes of `im` and `filter_grad` were previously
     * documented swapped (im as a filter shape, filter_grad as an image
     * shape); corrected below to match DeformableConvForward — confirm
     * against the implementation.
     *
     * \param[in] im (n, ic, ih, iw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad (oc, ic, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * NOTE(review): `im` was previously documented with a filter shape
     * (oc, ic, fh, fw); corrected below to the image shape used by
     * DeformableConvForward — confirm against the implementation.
     *
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
//! Common base of the deformable PS-ROI-pooling operators.
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! Deduce out_data and out_count layouts from the inputs.
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    //! Validate the forward-pass layouts.
    //! NOTE(review): unusually for a *_layout_fwd helper, this also takes a
    //! workspace size — confirm the intent against the implementation.
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * \brief Deformable position-sensitive ROI pooling forward pass.
     *
     * NOTE(review): the shape annotations below were placeholders in the
     * original header and look unreliable (e.g. "(xx, xx, xx, xx)") —
     * confirm against the implementation.
     *
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data (n, ic, ih, iw)
     * \param[out] out_count (n, ic, ih, iw)
     */
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * \brief Deformable position-sensitive ROI pooling backward pass.
     *
     * NOTE(review): the shape annotations below were placeholders in the
     * original header and look unreliable (e.g. "(xx, xx, xx, xx)") —
     * confirm against the implementation.
     *
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx)
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff (n, ic, ih, iw)
     * \param[out] trans_diff (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    //! Validate layouts and workspace size before exec().
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \brief Batched convolution with bias and optional residual input.
     *
     * NOTE(review): this class had no documentation; the descriptions below
     * are inferred from the signatures only — confirm shapes and semantics
     * against the implementation.
     *
     * \param[in] src convolution input tensor
     * \param[in] filter convolution filter tensor
     * \param[in] bias bias tensor combined into the convolution result
     * \param[in] z extra input tensor combined into the result
     *     (presumably a residual term; may be empty — confirm)
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! Deduce the dst dtype from the input dtypes.
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! Deduce the dst layout from the input layouts.
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;

protected:
    //! Validate layouts and workspace size; returns the canonized filter
    //! metadata.
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
  1262. } // namespace megdnn
  1263. #include "megdnn/internal/opr_header_epilogue.h"
  1264. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)