You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 67 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620
  1. /**
  2. * \file dnn/include/megdnn/oprs/nn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. class SeparableConvBase : public OperatorBase {
  16. DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
  17. DEF_OPR_PARAM(SeparableConv);
  18. public:
  19. using Mode = Param::Mode;
  20. protected:
  21. void deduce_layout_fwd(const TensorLayout& src,
  22. const TensorLayout& filter_x,
  23. const TensorLayout& filter_y, TensorLayout& dst);
  24. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
  25. const TensorLayout& filter_y,
  26. const TensorLayout& dst);
  27. };
  28. class SeparableConvForward : public SeparableConvBase {
  29. DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);
  30. public:
  31. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
  32. _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
  33. _megdnn_workspace workspace) = 0;
  34. void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
  35. const TensorLayout& filter_y, TensorLayout& dst);
  36. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  37. const TensorLayout& filter_x,
  38. const TensorLayout& filter_y,
  39. const TensorLayout& dst) = 0;
  40. protected:
  41. void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
  42. const TensorLayout& filter_y, const TensorLayout& dst,
  43. size_t workspace_in_bytes);
  44. };
  45. using SeparableConv = SeparableConvForward;
  46. namespace detail {
  47. struct PreprocessedFilter {
  48. //! user data; its lifetime should be bound to MegDNN Convolution
  49. //! operator
  50. void* algorithm_id;
  51. TensorNDArray tensors;
  52. };
  53. } // namespace detail
  54. /**
  55. * \brief base class for convolution operation
  56. *
  57. * This operator is supposed to perform convolution on arbitrary input
  58. * dimensions. The input/output format is N, C, dims..., and kernel format can
  59. * take two forms:
  60. * 1. OC, IC, dims..., for conventional dense convolution
  61. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  62. *
  63. * Currently, only 2D images are supported.
  64. */
  65. template <typename Parameter>
  66. class ConvolutionBase : public OperatorBase {
  67. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  68. using Param = Parameter;
  69. public:
  70. Param& param() { return m_param; }
  71. const Param& param() const { return m_param; }
  72. protected:
  73. Param m_param;
  74. public:
  75. static constexpr size_t MAX_SPATIAL_DIM = 2;
  76. using Mode = typename Param::Mode;
  77. struct CanonizedFilterMeta {
  78. DType dtype;
  79. typename Param::Format format;
  80. uint32_t
  81. //! whether filter should be flipped (i.e. is CONVOLUTION)
  82. should_flip,
  83. group, //!< number of groups
  84. icpg, //!< input channels per group
  85. ocpg, //!< output channels per group
  86. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  87. //! spatial dim
  88. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  89. //! spatial dim with dilation applied
  90. dilated_spatial[MAX_SPATIAL_DIM];
  91. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  92. template <typename T>
  93. void copy_from(const T& b) {
  94. dtype = b.dtype;
  95. format = b.format;
  96. should_flip = b.should_flip;
  97. group = b.group;
  98. icpg = b.icpg;
  99. ocpg = b.ocpg;
  100. spatial_ndim = b.spatial_ndim;
  101. memcpy(stride, b.stride, sizeof(stride));
  102. memcpy(padding, b.padding, sizeof(padding));
  103. memcpy(spatial, b.spatial, sizeof(spatial));
  104. memcpy(dilation, b.dilation, sizeof(dilation));
  105. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  106. }
  107. bool operator==(const CanonizedFilterMeta& b) const {
  108. bool flag = true;
  109. flag = flag && (format == b.format);
  110. flag = flag && (dtype == b.dtype);
  111. flag = flag && (should_flip == b.should_flip);
  112. flag = flag && (group == b.group);
  113. flag = flag && (icpg == b.icpg);
  114. flag = flag && (ocpg == b.ocpg);
  115. flag = flag && (spatial_ndim == b.spatial_ndim);
  116. if (flag) {
  117. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  118. flag = flag && (stride[i] == b.stride[i]);
  119. flag = flag && (padding[i] == b.padding[i]);
  120. flag = flag && (spatial[i] == b.spatial[i]);
  121. flag = flag && (dilation[i] == b.dilation[i]);
  122. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  123. }
  124. }
  125. return flag;
  126. }
  127. };
  128. using PreprocessedFilter = detail::PreprocessedFilter;
  129. protected:
  130. // Check or deduce output DType
  131. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  132. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  133. const TensorLayout& filter,
  134. TensorLayout& dst) const;
  135. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  136. const TensorLayout& filter,
  137. const TensorLayout& dst) const;
  138. CanonizedFilterMeta make_canonized_filter_meta(
  139. size_t src_ndim, const TensorLayout& filter) const;
  140. };
  141. class MaskPropagate : public OperatorBase {
  142. DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
  143. DEF_OPR_PARAM(MaskPropagate);
  144. public:
  145. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  146. _megdnn_workspace workspace) = 0;
  147. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  148. const TensorLayout& dst) = 0;
  149. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  150. };
  151. /**
  152. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  153. */
  154. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  155. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  156. public:
  157. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  158. _megdnn_tensor_in mask, _megdnn_tensor_out dst,
  159. _megdnn_workspace worksapce) = 0;
  160. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  161. const TensorLayout& filter,
  162. const TensorLayout& mask,
  163. const TensorLayout& dst) = 0;
  164. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  165. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  166. const TensorLayout& mask, TensorLayout& dst);
  167. protected:
  168. CanonizedFilterMeta check_exec(const TensorLayout& src,
  169. const TensorLayout& filter,
  170. const TensorLayout& mask,
  171. const TensorLayout& dst,
  172. size_t workspace_in_bytes);
  173. };
  174. using MaskConvolution = MaskConvForward;
  175. /**
  176. * \brief ConvolutionForward operator.
  177. */
  178. class ConvolutionForward : public ConvolutionBase<param::Convolution>,
  179. public detail::MultiAlgoOpr<ConvolutionForward, 3> {
  180. DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);
  181. public:
  182. /**
  183. * \param[in] src (n, ic, ih, iw)
  184. * \param[in] filter (oc, ic, fh, fw)
  185. * \param[out] dst (n, oc, oh, ow)
  186. */
  187. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  188. _megdnn_tensor_out dst,
  189. const PreprocessedFilter* preprocessed_filter,
  190. _megdnn_workspace workspace) = 0;
  191. virtual void exec_preprocess(const TensorLayout& src_layout,
  192. _megdnn_tensor_in filter,
  193. const TensorLayout& dst_layout,
  194. PreprocessedFilter* preprocessed_filter,
  195. _megdnn_workspace workspace) = 0;
  196. void deduce_dtype(DType src, DType filter, DType& dst);
  197. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  198. TensorLayout& dst);
  199. virtual size_t get_workspace_in_bytes(
  200. const TensorLayout& src, const TensorLayout& filter,
  201. const TensorLayout& dst,
  202. const PreprocessedFilter* preprocessed_filter) = 0;
  203. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  204. const TensorLayout& src, const TensorLayout& filter,
  205. const TensorLayout& dst) = 0;
  206. virtual size_t get_preprocess_workspace_in_bytes(
  207. const TensorLayout& src, const TensorLayout& filter,
  208. const TensorLayout& dst) = 0;
  209. protected:
  210. CanonizedFilterMeta check_exec(
  211. const TensorLayout& src, const TensorLayout& filter,
  212. const TensorLayout& dst, size_t workspace_in_bytes,
  213. const PreprocessedFilter* preprocessed_filter);
  214. };
  215. using Convolution = ConvolutionForward;
  216. /**
  217. * \brief ConvolutionBackwardData operator.
  218. *
  219. * Calculating the gradient wrt. convolution input data.
  220. */
  221. class ConvolutionBackwardData
  222. : public ConvolutionBase<param::Convolution>,
  223. public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
  224. DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);
  225. public:
  226. /**
  227. * \param[in] filter (oc, ic, fh, fw)
  228. * \param[in] diff (n, oc, oh, ow)
  229. * \param[out] grad (n, ic, ih, iw)
  230. */
  231. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  232. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  233. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  234. const TensorLayout& diff,
  235. const TensorLayout& grad) = 0;
  236. void deduce_dtype(DType filter, DType diff, DType& grad);
  237. void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
  238. TensorLayout& grad);
  239. protected:
  240. CanonizedFilterMeta check_exec(const TensorLayout& filter,
  241. const TensorLayout& diff,
  242. const TensorLayout& grad,
  243. size_t workspace_in_bytes);
  244. };
  245. /**
  246. * \brief ConvolutionBackwardFilter operator.
  247. *
  248. * Calculating the gradient wrt. convolution filter.
  249. */
  250. class ConvolutionBackwardFilter
  251. : public ConvolutionBase<param::Convolution>,
  252. public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
  253. DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);
  254. public:
  255. /**
  256. * \param[in] src (n, ic, ih, iw)
  257. * \param[in] diff (n, oc, oh, ow)
  258. * \param[out] grad (oc, ic, fh, fw)
  259. */
  260. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  261. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  262. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  263. const TensorLayout& diff,
  264. const TensorLayout& grad) = 0;
  265. protected:
  266. CanonizedFilterMeta check_exec(const TensorLayout& src,
  267. const TensorLayout& diff,
  268. const TensorLayout& grad,
  269. size_t workspace_in_bytes);
  270. };
  271. /**
  272. * \brief ConvolutionBias operator
  273. */
  274. class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
  275. public detail::MultiAlgoOpr<ConvBiasForward, 5> {
  276. DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);
  277. public:
  278. /**
  279. * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
  280. * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
  281. * 4 * ic)
  282. * \param[in] bias (1, oc, 1, 1)
  283. * \param[in] z same as dst
  284. * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
  285. *
  286. * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
  287. * alphaw, oc, ic)
  288. */
  289. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  290. _megdnn_tensor_in bias, _megdnn_tensor_in z,
  291. _megdnn_tensor_out dst,
  292. const PreprocessedFilter* preprocessed_filter,
  293. _megdnn_workspace workspace) = 0;
  294. virtual void exec_preprocess(const TensorLayout& src_layout,
  295. _megdnn_tensor_in filter,
  296. const TensorLayout& bias_layout,
  297. const TensorLayout& z_layout,
  298. const TensorLayout& dst_layout,
  299. PreprocessedFilter* preprocessed_filter,
  300. _megdnn_workspace workspace) = 0;
  301. void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
  302. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  303. const TensorLayout& bias, const TensorLayout& z,
  304. TensorLayout& dst);
  305. virtual size_t get_workspace_in_bytes(
  306. const TensorLayout& src, const TensorLayout& filter,
  307. const TensorLayout& bias, const TensorLayout& z,
  308. const TensorLayout& dst,
  309. const PreprocessedFilter* preprocessed_filter) = 0;
  310. virtual size_t get_preprocess_workspace_in_bytes(
  311. const TensorLayout& src, const TensorLayout& filter,
  312. const TensorLayout& bias, const TensorLayout& z,
  313. const TensorLayout& dst) = 0;
  314. virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  315. const TensorLayout& src, const TensorLayout& filter,
  316. const TensorLayout& bias, const TensorLayout& z,
  317. const TensorLayout& dst) = 0;
  318. static void deduce_winograd_origin_layout_and_param(
  319. const Param::Format format, const size_t output_block_size,
  320. const TensorLayout& src_layout,
  321. const TensorLayout& winograd_filter_layout,
  322. TensorLayout& origin_layout, Param& origin_param);
  323. enum class BiasMode : uint32_t {
  324. NO_BIAS = 0, //!< no bias
  325. BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
  326. BIAS //!< [N, C, H, W]
  327. };
  328. //! param for winograd algos.
  329. struct WinogradParam {
  330. uint32_t channel_block_size;
  331. uint32_t output_block_size;
  332. uint32_t tile_size;
  333. bool operator==(const WinogradParam& rhs) const {
  334. return channel_block_size == rhs.channel_block_size &&
  335. output_block_size == rhs.output_block_size &&
  336. tile_size == rhs.tile_size;
  337. }
  338. std::string to_string() const;
  339. };
  340. static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
  341. struct DirectParam {
  342. std::string to_string() const { return ""; }
  343. };
  344. struct MatmulParam {
  345. std::string to_string() const { return ""; }
  346. };
  347. struct DefaultParam {
  348. std::string to_string() const { return ""; }
  349. };
  350. //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
  351. //! \warning: base must not contain :.
  352. template <typename T>
  353. static std::string algo_name(
  354. const std::string& base, const T& p,
  355. param::ConvBias::Format format = param::ConvBias::Format::NCHW);
  356. /*!
  357. * \brief parse algo_name and get WinogradParam from algo name.
  358. *
  359. * \param algo name string
  360. * \return WinogradParam parsed from algo name, use pattern
  361. * winograd:base:m:tile_size.
  362. *
  363. * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
  364. */
  365. static WinogradParam parse_winograd_name(const std::string& algo_name);
  366. /**
  367. * @brief find if there is nchw_nchwxx conv kernel optimized for argment,
  368. * nchw44 used for arm, nchw88 used for x86
  369. *
  370. * @param src_dtype conv feature map data type
  371. * @param filter_dtype conv filter or weight data type
  372. * @param dst_dtype output data type
  373. * @param fm filter meta param
  374. * @param bias_mode bias mode, no_bias or broadcast or bias
  375. * @param nonline_mode identity or relu or h_swish or sigmoid
  376. * @return true, found a kernel
  377. * @return false, can`t found any kernel
  378. */
  379. static bool is_nchw_nchwxx_optimized(
  380. const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
  381. const DTypeEnum dst_dtype,
  382. const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
  383. const ConvBiasForward::BiasMode bias_mode,
  384. const param::ConvBias::NonlineMode nonline_mode);
  385. protected:
  386. CanonizedFilterMeta check_exec(
  387. const TensorLayout& src, const TensorLayout& filter,
  388. const TensorLayout& bias, const TensorLayout& z,
  389. const TensorLayout& dst, size_t workspace_in_bytes,
  390. const PreprocessedFilter* preprocessed_filter);
  391. };
  392. using ConvBias = ConvBiasForward;
  393. /**
  394. * \brief base class for Conv - Nonline - Pooling
  395. */
  396. class ConvPoolingBase : public OperatorBase {
  397. DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
  398. /**
  399. * \ Param::Method: Two methods to fetch the input data.
  400. * Default methods is WITH_TEXTURE_OBJ.
  401. * If you want to use WITH_SHARED_MEM mode,
  402. * please make sure that the size of
  403. * [ all of the fliter kernels + a channel
  404. * of input data + a channel of output data]
  405. * should be no large than 38KB.
  406. * And the pooling mode should not be "MAX".
  407. */
  408. DEF_OPR_PARAM(ConvPooling);
  409. protected:
  410. virtual void deduce_layout(const TensorLayout& src,
  411. const TensorLayout& filter,
  412. const TensorLayout& bias, TensorLayout& dst) = 0;
  413. virtual void check_layout(const TensorLayout& src,
  414. const TensorLayout& filter,
  415. const TensorLayout& bias, TensorLayout& dst,
  416. size_t workspace_limit_in_bytes) = 0;
  417. };
  418. class ConvPoolingForward : public ConvPoolingBase {
  419. DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);
  420. public:
  421. /**
  422. * \param[in] src input tensor
  423. * \param[out] dst output tensor
  424. */
  425. virtual void exec(const _megdnn_in TensorND src,
  426. const _megdnn_in TensorND filter,
  427. const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
  428. _megdnn_out Workspace workspace) = 0;
  429. virtual void deduce_layout(const TensorLayout& src,
  430. const TensorLayout& filter,
  431. const TensorLayout& bias, TensorLayout& dst) = 0;
  432. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  433. const TensorLayout& filter,
  434. const TensorLayout& bias,
  435. const TensorLayout& dst) = 0;
  436. protected:
  437. virtual void check_layout(const TensorLayout& src,
  438. const TensorLayout& filter,
  439. const TensorLayout& bias, TensorLayout& dst,
  440. size_t workspace_limit_in_bytes) = 0;
  441. };
  442. using ConvPooling = ConvPoolingForward;
  443. class GroupLocalBase : public OperatorBase {
  444. DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
  445. DEF_OPR_PARAM(Convolution);
  446. public:
  447. using Mode = Param::Mode;
  448. protected:
  449. void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  450. TensorLayout& dst);
  451. void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
  452. const TensorLayout& dst);
  453. };
  454. class GroupLocalForward : public GroupLocalBase {
  455. DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);
  456. public:
  457. /**
  458. * \param[in] src (N, IC, IH, IW)
  459. * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
  460. * \param[out] dst (N, OC, OH, OW)
  461. **/
  462. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  463. _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  464. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  465. TensorLayout& dst) {
  466. deduce_layout_fwd(src, filter, dst);
  467. }
  468. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  469. const TensorLayout& filter,
  470. const TensorLayout& dst) = 0;
  471. protected:
  472. void check_exec(const TensorLayout& src, const TensorLayout& filter,
  473. const TensorLayout& dst, size_t workspace_in_bytes);
  474. };
  475. using GroupLocal = GroupLocalForward;
  476. class GroupLocalBackwardData : public GroupLocalBase {
  477. DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);
  478. public:
  479. virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
  480. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  481. virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
  482. const TensorLayout& diff,
  483. const TensorLayout& grad) = 0;
  484. protected:
  485. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  486. const TensorLayout& grad, size_t workspace_in_bytes);
  487. };
  488. class GroupLocalBackwardFilter : public GroupLocalBase {
  489. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  490. public:
  491. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  492. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  493. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  494. const TensorLayout& diff,
  495. const TensorLayout& grad) = 0;
  496. protected:
  497. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  498. const TensorLayout& grad, size_t workspace_in_bytes);
  499. };
  500. class Images2NeibsBase : public OperatorBase {
  501. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  502. DEF_OPR_PARAM(Images2Neibs);
  503. protected:
  504. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  505. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  506. };
  507. class Images2NeibsForward : public Images2NeibsBase {
  508. DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);
  509. public:
  510. /**
  511. * \param[in] src (N, C, IH, IW)
  512. * \param[out] dst (N, C, OH, OW, window_h, window_w)
  513. *
  514. * \see
  515. * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
  516. *
  517. * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
  518. * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$.
  519. */
  520. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  521. _megdnn_workspace workspace) = 0;
  522. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  523. const TensorLayout& dst) = 0;
  524. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  525. protected:
  526. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  527. size_t workspace_in_bytes);
  528. };
  529. using Images2Neibs = Images2NeibsForward;
  530. class Images2NeibsBackward : public Images2NeibsBase {
  531. DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);
  532. public:
  533. /**
  534. * \param[in] diff the backpropagated gradient wrt. dst
  535. * \param[out] grad the backpropagated gradient wrt. src
  536. */
  537. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
  538. _megdnn_workspace workspace) = 0;
  539. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  540. const TensorLayout& grad) = 0;
  541. protected:
  542. void check_exec(const TensorLayout& diff, const TensorLayout& grad,
  543. size_t workspace_in_bytes);
  544. };
  545. /**
  546. * \brief base class for Pooling
  547. */
  548. class PoolingBase : public OperatorBase {
  549. DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
  550. DEF_OPR_PARAM(Pooling);
  551. public:
  552. using Mode = Param::Mode;
  553. protected:
  554. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  555. void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
  556. };
  557. class PoolingForward : public PoolingBase {
  558. DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);
  559. public:
  560. /**
  561. * \param[in] src input tensor
  562. * \param[out] dst output tensor
  563. */
  564. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  565. _megdnn_workspace workspace) = 0;
  566. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  567. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  568. const TensorLayout& dst) = 0;
  569. protected:
  570. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  571. size_t workspace_in_bytes);
  572. };
  573. using Pooling = PoolingForward;
  574. class PoolingBackward : public PoolingBase {
  575. DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);
  576. public:
  577. /**
  578. * \param[in] src the `src' parameter in PoolingForward::exec
  579. * \param[in] dst the `dst' parameter in PoolingForward::exec
  580. * \param[in] diff the backpropagated gradient wrt. dst
  581. * \param[out] grad the backpropagated gradient wrt. src
  582. */
  583. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
  584. _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  585. _megdnn_workspace workspace) = 0;
  586. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  587. const TensorLayout& dst,
  588. const TensorLayout& diff,
  589. const TensorLayout& grad) = 0;
  590. protected:
  591. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  592. const TensorLayout& diff, const TensorLayout& grad,
  593. size_t workspace_in_bytes);
  594. };
  595. /**
  596. * \brief base class for AdaptivePooling
  597. */
  598. class AdaptivePoolingBase : public OperatorBase {
  599. DEF_OPR_IMPL_CTOR(AdaptivePoolingBase, OperatorBase);
  600. DEF_OPR_PARAM(AdaptivePooling);
  601. protected:
  602. param::Pooling deduce_pooling_param(const TensorLayout& src,
  603. const TensorLayout& dst);
  604. };
  605. class AdaptivePoolingForward : public AdaptivePoolingBase {
  606. DEF_OPR_IMPL(AdaptivePoolingForward, AdaptivePoolingBase, 1, 1);
  607. public:
  608. /**
  609. * \param[in] src input tensor
  610. * \param[out] dst output tensor
  611. */
  612. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  613. _megdnn_workspace workspace) = 0;
  614. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  615. const TensorLayout& dst) = 0;
  616. };
  617. using AdaptivePooling = AdaptivePoolingForward;
/**
 * \brief gradient of AdaptivePoolingForward wrt. its input.
 */
class AdaptivePoolingBackward : public AdaptivePoolingBase {
    DEF_OPR_IMPL(AdaptivePoolingBackward, AdaptivePoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in AdaptivePoolingForward::exec
     * \param[in] dst the `dst' parameter in AdaptivePoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! size in bytes of the extra workspace exec() requires for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
};
/**
 * \brief base class for Local (locally-connected) convolution operators
 *
 * Uses the Convolution param; shared layout deduction/checking for the
 * forward direction lives here.
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! compute dst layout from src and filter layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate that (src, filter, dst) form a consistent forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
/**
 * \brief forward locally-connected layer: like convolution, but with an
 * independent filter per output spatial position (filter is indexed by
 * (oh, ow)).
 */
class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! size in bytes of the extra workspace exec() requires for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
/**
 * \brief gradient of LocalForward wrt. its input data.
 */
class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! size in bytes of the extra workspace exec() requires for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief gradient of LocalForward wrt. its filter.
 */
class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! size in bytes of the extra workspace exec() requires for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for batch normalization operators; holds the BN param.
 */
class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    //! validate the BN param before execution
    void check_param();
};
/**
 * \brief forward batch normalization.
 */
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 5);

public:
    /**
     * dst[i] = gamma
     * * (x[i] - estimatedMean[k]) / sqrt(epsilon + estimatedVariance[k])
     * + beta, where epsilon is a very small value to avoid a
     * "divide by zero" error.
     *
     * \param[in] src (n, c, h, w)
     * \param[in] bn_scale per-channel scale (gamma); see m_param.ParamDim
     * \param[in] bn_bias per-channel bias (beta); see m_param.ParamDim
     * \param[in,out] mean (see m_param.ParamDim) Global mean.
     * \param[in,out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *      Optionally cached intermediate mean from forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *      Optionally cached intermediate variance from forward pass
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                      _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
                      _megdnn_tensor_inout variance,
                      _megdnn_tensor_out batch_mean,
                      _megdnn_tensor_out batch_inv_variance,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce all non-src layouts from the src layout and the BN param
    void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
                       TensorLayout& bn_bias, TensorLayout& mean,
                       TensorLayout& variance, TensorLayout& batch_mean,
                       TensorLayout& batch_inv_variance, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
                    const TensorLayout& bn_bias, const TensorLayout& mean,
                    const TensorLayout& variance,
                    const TensorLayout& batch_mean,
                    const TensorLayout& batch_inv_variance,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using BN = BNForward;
/**
 * \brief gradient of BNForward wrt. input, scale and bias.
 */
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);

public:
    /**
     * \param[in] x input data of the forward pass.
     * \param[in] dy the backpropagated gradient of y.
     * \param[in] saved_batch_mean mean of the input batch,
     *      calculated in the forward pass.
     * \param[in] saved_batch_variance variance of the input batch,
     *      calculated in the forward pass.
     * \param[in] bn_scale the scale (gamma) used in the forward pass.
     * \param[out] d_bn_scale the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias the backpropagated gradient of bn_bias.
     * \param[out] dx the backpropagated gradient of x.
     *
     * saved_batch_mean/saved_batch_variance are optionally cached
     * intermediate results from the forward pass.
     */
    virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                      _megdnn_tensor_in saved_batch_mean,
                      _megdnn_tensor_in saved_batch_variance,
                      _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
                      _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance,
            const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;

protected:
    void check_exec(const TensorLayout& x, const TensorLayout& dy,
                    const TensorLayout& saved_batch_mean,
                    const TensorLayout& saved_batch_variance,
                    const TensorLayout& bn_scale,
                    const TensorLayout& d_bn_scale,
                    const TensorLayout& d_bn_bias, const TensorLayout& dx,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for local response normalization; holds the LRN param.
 */
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! validate the LRN param before execution
    void check_param();
};
/**
 * \brief forward local response normalization.
 */
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! size in bytes of the extra workspace exec() requires for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using LRN = LRNForward;
/**
 * \brief gradient of LRNForward wrt. its input.
 */
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for ROI pooling; holds the ROIPooling param.
 */
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate that (src, rois, dst, index) form a consistent forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
/**
 * \brief forward ROI pooling.
 */
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIPooling = ROIPoolingForward;
/**
 * \brief gradient of ROIPoolingForward wrt. its input.
 */
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
                      _megdnn_tensor_in rois, _megdnn_tensor_in index,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& src,
                    const TensorLayout& rois, const TensorLayout& index,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for 3D convolution operators; holds the Convolution3D
 * param and the canonized filter metadata shared by fwd/bwd operators.
 */
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;
    //! canonical (format-independent) description of a filter
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! compute dst layout from src/filter; returns the canonized filter meta
    CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          TensorLayout& dst) const;
    //! validate layouts of a forward problem; returns the canonized meta
    CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
/**
 * \brief forward 3D convolution; supports multiple algorithms via
 * detail::MultiAlgoOpr.
 */
class Convolution3DForward
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using Convolution3D = Convolution3DForward;
/**
 * \brief gradient of Convolution3DForward wrt. its input data.
 */
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (n, ic, id, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief gradient of Convolution3DForward wrt. its filter.
 */
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow)
     * \param[out] grad (oc, ic, fd, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief base class for local-share convolution; holds the LocalShare param.
 */
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! compute dst layout from src and filter layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate that (src, filter, dst) form a consistent forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
/**
 * \brief forward local-share convolution: filters are shared within spatial
 * groups rather than being fully shared (convolution) or fully per-position
 * (local).
 */
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using LocalShare = LocalShareForward;
/**
 * \brief gradient of LocalShareForward wrt. its input data.
 */
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (N, IC, IH, IW)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief gradient of LocalShareForward wrt. its filter.
 */
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW)
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     *      FH, FW, OC / G)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
/**
 * \brief base class for ROI align; holds the ROIAlign param.
 */
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! compute dst/index layouts from src and rois layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    //! validate that (src, rois, dst, index) form a consistent forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
/**
 * \brief forward ROI align.
 */
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
using ROIAlign = ROIAlignForward;
/**
 * \brief gradient of ROIAlignForward wrt. its input.
 */
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for deformable convolution operators.
 *
 * Uses the Convolution param; extends the canonized filter metadata with
 * the number of deformable groups.
 */
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    static constexpr size_t MAX_SPATIAL_DIM = 2;
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;  //!< number of deformable groups (dg)
    };

protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    //! compute dst layout from the forward inputs
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    //! validate the layouts of a forward problem
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
/**
 * \brief forward deformable convolution (convolution with learned
 * per-position sampling offsets and modulation masks).
 */
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im input image of the forward pass
     *      NOTE(review): original comment said (oc, ic, fh, fw), which looks
     *      copy-pasted from the filter shape — presumably (n, ic, ih, iw);
     *      verify against the implementation.
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] filter_grad gradient wrt. the filter
     *      NOTE(review): original comment said (oc, ic, ih, iw) — presumably
     *      (oc, ic, fh, fw) to match the filter; verify.
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im input image of the forward pass
     *      NOTE(review): original comment said (oc, ic, fh, fw), same as the
     *      filter — presumably (n, ic, ih, iw); verify against the
     *      implementation.
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow)
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);

protected:
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
/**
 * \brief base class for deformable position-sensitive ROI pooling.
 */
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! compute out_data/out_count layouts from the forward inputs
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    //! validate the layouts of a forward problem
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
/**
 * \brief forward deformable position-sensitive ROI pooling.
 */
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * NOTE(review): the shapes below are from the original comment and look
     * placeholder-ish ((xx, xx, xx, xx), identical shapes for data/out_data)
     * — verify against the implementation before relying on them.
     *
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data ( n, ic, ih, iw)
     * \param[out] out_count ( n, ic, ih, iw)
     */
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
/**
 * \brief gradient of DeformablePSROIPoolingForward wrt. data and trans.
 */
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * NOTE(review): the shapes below are from the original comment and look
     * placeholder-ish ((xx, xx, xx, xx)) — verify against the implementation
     * before relying on them.
     *
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx)
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff ( n, ic, ih, iw)
     * \param[out] trans_diff ( n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
/**
 * \brief batched convolution with bias and optional residual addition (z),
 * parameterized by param::BatchConvBias.
 */
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter convolution filter
     * \param[in] bias bias added to the convolution result
     * \param[in] z extra input fused into the output (residual term)
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;

protected:
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
using BatchConvBias = BatchConvBiasForward;
/**
 * \brief Common base for the FakeQuant forward/backward operators.
 *
 * Holds the param::FakeQuant parameter and provides the shared layout
 * deduction/validation helpers used by both directions.
 */
class FakeQuantBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
    DEF_OPR_PARAM(FakeQuant);

protected:
    //! Output layout is deduced from \p input alone (scale/zero_point do not
    //! affect the output shape).
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
                          const TensorLayout& zero_point,
                          const TensorLayout& output);
};
/**
 * \brief Forward fake-quantization: produces \p output from \p input using
 * per-tensor-or-per-channel \p scale and \p zero_point (granularity not
 * visible here — confirm against param::FakeQuant and the backends).
 *
 * 3 inputs, 1 output.
 */
class FakeQuantForward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);

public:
    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
                      _megdnn_tensor_in zero_point, _megdnn_tensor_out output,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
                       const TensorLayout& zero_point, TensorLayout& output);

    //! Scratch-space requirement for the given layouts.
    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& output) = 0;

protected:
    //! Validates layouts and workspace size; called by backend exec().
    void check_exec(const TensorLayout& input, const TensorLayout& scale,
                    const TensorLayout& zero_point, const TensorLayout& output,
                    size_t workspace_in_bytes);
};

//! Convenience alias: the forward operator is the canonical name.
using FakeQuant = FakeQuantForward;
/**
 * \brief Gradient operator for FakeQuant: given \p diff (gradient w.r.t. the
 * forward output) and the forward inputs, produces \p grad (gradient w.r.t.
 * \p input). 4 inputs, 1 output.
 */
class FakeQuantBackward : public FakeQuantBase {
    DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);

public:
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
                      _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

    //! Scratch-space requirement for the given layouts.
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& input,
                                          const TensorLayout& scale,
                                          const TensorLayout& zero_point,
                                          const TensorLayout& grad) = 0;

protected:
    //! Validates layouts and workspace size; called by backend exec().
    void check_exec(const TensorLayout& diff, const TensorLayout& input,
                    const TensorLayout& scale, const TensorLayout& zero_point,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
  1413. } // namespace megdnn
  1414. #include "megdnn/internal/opr_header_epilogue.h"
  1415. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。