You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn.h 62 kB

/**
 * \file dnn/include/megdnn/oprs/nn.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
/**
 * \brief base class for separable convolution operators
 *
 * Holds the SeparableConv param and the layout deduction/checking helpers
 * shared by the forward operator.
 */
class SeparableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SeparableConvBase, OperatorBase);
    DEF_OPR_PARAM(SeparableConv);

public:
    using Mode = Param::Mode;

protected:
    //! deduce \p dst layout from \p src and the two filter layouts
    void deduce_layout_fwd(const TensorLayout& src,
                           const TensorLayout& filter_x,
                           const TensorLayout& filter_y, TensorLayout& dst);
    //! verify that the given layouts are mutually consistent
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
                          const TensorLayout& filter_y,
                          const TensorLayout& dst);
};
/**
 * \brief forward separable convolution: 3 inputs (src, filter_x, filter_y),
 * 1 output (dst)
 */
class SeparableConvForward : public SeparableConvBase {
    DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter_x first filter tensor
     * \param[in] filter_y second filter tensor
     * \param[out] dst output tensor
     * \param workspace scratch memory; its required size is given by
     *     get_workspace_in_bytes()
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
                      _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce the output layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
                       const TensorLayout& filter_y, TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec() for these layouts
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter_x,
                                          const TensorLayout& filter_y,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
                    const TensorLayout& filter_y, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using SeparableConv = SeparableConvForward;
  45. /**
  46. * \brief base class for convolution operation
  47. *
  48. * This operator is supposed to perform convolution on arbitrary input
  49. * dimensions. The input/output format is N, C, dims..., and kernel format can
  50. * take two forms:
  51. * 1. OC, IC, dims..., for conventional dense convolution
  52. * 2. GROUP, OC_PER_GRP, IC_PER_GRP, dims... for sparse group convolution
  53. *
  54. * Currently, only 2D images are supported.
  55. */
  56. template <typename Parameter>
  57. class ConvolutionBase : public OperatorBase {
  58. DEF_OPR_IMPL_CTOR(ConvolutionBase, OperatorBase);
  59. using Param = Parameter;
  60. public:
  61. Param& param() { return m_param; }
  62. const Param& param() const { return m_param; }
  63. protected:
  64. Param m_param;
  65. public:
  66. static constexpr size_t MAX_SPATIAL_DIM = 2;
  67. using Mode = typename Param::Mode;
  68. struct CanonizedFilterMeta {
  69. DType dtype;
  70. typename Param::Format format;
  71. uint32_t
  72. //! whether filter should be flipped (i.e. is CONVOLUTION)
  73. should_flip,
  74. group, //!< number of groups
  75. icpg, //!< input channels per group
  76. ocpg, //!< output channels per group
  77. spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
  78. //! spatial dim
  79. spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
  80. //! spatial dim with dilation applied
  81. dilated_spatial[MAX_SPATIAL_DIM];
  82. //! T should be a ConvolutionBase<Z>::CanonizedFilterMeta
  83. template <typename T>
  84. void copy_from(const T& b) {
  85. dtype = b.dtype;
  86. format = b.format;
  87. should_flip = b.should_flip;
  88. group = b.group;
  89. icpg = b.icpg;
  90. ocpg = b.ocpg;
  91. spatial_ndim = b.spatial_ndim;
  92. memcpy(stride, b.stride, sizeof(stride));
  93. memcpy(padding, b.padding, sizeof(padding));
  94. memcpy(spatial, b.spatial, sizeof(spatial));
  95. memcpy(dilation, b.dilation, sizeof(dilation));
  96. memcpy(dilated_spatial, b.dilated_spatial, sizeof(dilated_spatial));
  97. }
  98. bool operator==(const CanonizedFilterMeta& b) const {
  99. bool flag = true;
  100. flag = flag && (format == b.format);
  101. flag = flag && (dtype == b.dtype);
  102. flag = flag && (should_flip == b.should_flip);
  103. flag = flag && (group == b.group);
  104. flag = flag && (icpg == b.icpg);
  105. flag = flag && (ocpg == b.ocpg);
  106. flag = flag && (spatial_ndim == b.spatial_ndim);
  107. if (flag) {
  108. for (uint32_t i = 0; i < spatial_ndim; ++i) {
  109. flag = flag && (stride[i] == b.stride[i]);
  110. flag = flag && (padding[i] == b.padding[i]);
  111. flag = flag && (spatial[i] == b.spatial[i]);
  112. flag = flag && (dilation[i] == b.dilation[i]);
  113. flag = flag && (dilated_spatial[i] == b.dilated_spatial[i]);
  114. }
  115. }
  116. return flag;
  117. }
  118. };
  119. struct PreprocessedFilter {
  120. //! user data; its lifetime should be bound to MegDNN Convolution
  121. //! operator
  122. void* algorithm_id;
  123. TensorNDArray tensors;
  124. };
  125. protected:
  126. // Check or deduce output DType
  127. void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
  128. CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
  129. const TensorLayout& filter,
  130. TensorLayout& dst) const;
  131. CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
  132. const TensorLayout& filter,
  133. const TensorLayout& dst) const;
  134. CanonizedFilterMeta make_canonized_filter_meta(
  135. size_t src_ndim, const TensorLayout& filter) const;
  136. };
//! propagate a mask tensor (1 input, 1 output); exact semantics are governed
//! by Param::MaskPropagate — see the implementation for details
class MaskPropagate : public OperatorBase {
    DEF_OPR_IMPL(MaskPropagate, OperatorBase, 1, 1);
    DEF_OPR_PARAM(MaskPropagate);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
};
  147. /**
  148. * \brief ConvolutionForward Operator with 0/1 Mask matrix
  149. */
  150. class MaskConvForward : public ConvolutionBase<param::Convolution> {
  151. DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1);
  152. public:
  153. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  154. _megdnn_tensor_in mask, _megdnn_tensor_out dst,
  155. _megdnn_workspace worksapce) = 0;
  156. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  157. const TensorLayout& filter,
  158. const TensorLayout& mask,
  159. const TensorLayout& dst) = 0;
  160. void deduce_dtype(DType src, DType filter, DType mask, DType& dst);
  161. void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
  162. const TensorLayout& mask, TensorLayout& dst);
  163. protected:
  164. CanonizedFilterMeta check_exec(const TensorLayout& src,
  165. const TensorLayout& filter,
  166. const TensorLayout& mask,
  167. const TensorLayout& dst,
  168. size_t workspace_in_bytes);
  169. };
  170. using MaskConvolution = MaskConvForward;
/**
 * \brief ConvolutionForward operator.
 */
class ConvolutionForward : public ConvolutionBase<param::Convolution>,
                           public detail::MultiAlgoOpr<ConvolutionForward, 3> {
    DEF_OPR_IMPL(ConvolutionForward, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[out] dst (n, oc, oh, ow)
     * \param[in] preprocessed_filter data prepared by exec_preprocess();
     *     NOTE(review): presumably may be nullptr when no preprocessing was
     *     done — confirm against implementations
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst,
                      const PreprocessedFilter* preprocessed_filter,
                      _megdnn_workspace workspace) = 0;
    //! fill \p preprocessed_filter from \p filter ahead of exec(); layouts of
    //! the preprocessed tensors come from deduce_preprocessed_filter_layout()
    virtual void exec_preprocess(const TensorLayout& src_layout,
                                 _megdnn_tensor_in filter,
                                 const TensorLayout& dst_layout,
                                 PreprocessedFilter* preprocessed_filter,
                                 _megdnn_workspace workspace) = 0;
    //! check or deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType& dst);
    //! deduce the output layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    //! layouts of the tensors produced by filter preprocessing
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;
    //! minimum workspace size (in bytes) required by exec_preprocess()
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size; returns the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using Convolution = ConvolutionForward;
/**
 * \brief ConvolutionBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data.
 */
class ConvolutionBackwardData
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardData, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardData, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! check or deduce the gradient dtype from filter/diff dtypes
    void deduce_dtype(DType filter, DType diff, DType& grad);
    //! deduce the gradient layout from filter/diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts and workspace size; returns the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class ConvolutionBackwardFilter
        : public ConvolutionBase<param::Convolution>,
          public detail::MultiAlgoOpr<ConvolutionBackwardFilter, 3> {
    DEF_OPR_IMPL(ConvolutionBackwardFilter, ConvolutionBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oc, ic, fh, fw)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size; returns the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief ConvolutionBias operator
 *
 * Fused convolution + bias add (+ optional residual input z), with an
 * optional filter-preprocessing path and winograd algo helpers.
 */
class ConvBiasForward : public ConvolutionBase<param::ConvBias>,
                        public detail::MultiAlgoOpr<ConvBiasForward, 5> {
    DEF_OPR_IMPL(ConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
     * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
     * 4*ic)
     * \param[in] bias (1, oc, 1, 1)
     * \param[in] z same as dst
     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
     *
     * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
     * alphaw, oc, ic)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst,
                      const PreprocessedFilter* preprocessed_filter,
                      _megdnn_workspace workspace) = 0;
    //! fill \p preprocessed_filter from \p filter ahead of exec()
    virtual void exec_preprocess(const TensorLayout& src_layout,
                                 _megdnn_tensor_in filter,
                                 const TensorLayout& bias_layout,
                                 const TensorLayout& z_layout,
                                 const TensorLayout& dst_layout,
                                 PreprocessedFilter* preprocessed_filter,
                                 _megdnn_workspace workspace) = 0;
    //! check or deduce the output dtype from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! deduce the output layout from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst,
            const PreprocessedFilter* preprocessed_filter) = 0;
    //! minimum workspace size (in bytes) required by exec_preprocess()
    virtual size_t get_preprocess_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    //! layouts of the tensors produced by filter preprocessing
    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst) = 0;
    //! recover the original (pre-winograd) filter layout and param from a
    //! winograd-transformed filter layout
    static void deduce_winograd_origin_layout_and_param(
            const Param::Format format, const size_t output_block_size,
            const TensorLayout& src_layout,
            const TensorLayout& winograd_filter_layout,
            TensorLayout& origin_layout, Param& origin_param);

    enum class BiasMode : uint32_t {
        NO_BIAS = 0,             //!< no bias
        BROADCAST_CHANNEL_BIAS,  //!< broadcast channel bias, [1, c, 1, 1]
        BIAS                     //!< [N, C, H, W]
    };

    //! param for winograd algos.
    struct WinogradParam {
        uint32_t channel_block_size;
        uint32_t output_block_size;
        uint32_t tile_size;
        bool operator==(const WinogradParam& rhs) const {
            return channel_block_size == rhs.channel_block_size &&
                   output_block_size == rhs.output_block_size &&
                   tile_size == rhs.tile_size;
        }
        std::string to_string() const;
    };
    //! sentinel value returned by parse_winograd_name() on mismatch
    static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};

    struct DirectParam {
        std::string to_string() const { return ""; }
    };
    struct MatmulParam {
        std::string to_string() const { return ""; }
    };
    struct DefaultParam {
        std::string to_string() const { return ""; }
    };

    //! get algo name, the format is ParamTrait<T>::category:base:p.to_string()
    //! \warning: base must not contain :.
    template <typename T>
    static std::string algo_name(
            const std::string& base, const T& p,
            param::ConvBias::Format format = param::ConvBias::Format::NCHW);
    /*!
     * \brief parse algo_name and get WinogradParam from algo name.
     *
     * \param algo name string
     * \return WinogradParam parsed from algo name, use pattern
     * winograd:base:m:tile_size.
     *
     * \warning: INVALID_WINOGRAD_PARAM returns if the algo_name is not matched.
     */
    static WinogradParam parse_winograd_name(const std::string& algo_name);

protected:
    //! validate layouts and workspace size; returns the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& bias, const TensorLayout& z,
            const TensorLayout& dst, size_t workspace_in_bytes,
            const PreprocessedFilter* preprocessed_filter);
};
using ConvBias = ConvBiasForward;
/**
 * \brief base class for Conv - Nonline - Pooling
 */
class ConvPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ConvPoolingBase, OperatorBase);
    /**
     * Param::Method: Two methods to fetch the input data.
     * Default method is WITH_TEXTURE_OBJ.
     * If you want to use WITH_SHARED_MEM mode,
     * please make sure that the size of
     * [ all of the filter kernels + a channel
     * of input data + a channel of output data]
     * should be no larger than 38KB.
     * And the pooling mode should not be "MAX".
     */
    DEF_OPR_PARAM(ConvPooling);

protected:
    //! deduce \p dst layout from the input layouts
    virtual void deduce_layout(const TensorLayout& src,
                               const TensorLayout& filter,
                               const TensorLayout& bias, TensorLayout& dst) = 0;
    //! verify layouts and the workspace limit before execution
    virtual void check_layout(const TensorLayout& src,
                              const TensorLayout& filter,
                              const TensorLayout& bias, TensorLayout& dst,
                              size_t workspace_limit_in_bytes) = 0;
};
class ConvPoolingForward : public ConvPoolingBase {
    DEF_OPR_IMPL(ConvPoolingForward, ConvPoolingBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter convolution filter
     * \param[in] bias bias tensor
     * \param[out] dst output tensor
     */
    virtual void exec(const _megdnn_in TensorND src,
                      const _megdnn_in TensorND filter,
                      const _megdnn_in TensorND bias, _megdnn_out TensorND dst,
                      _megdnn_out Workspace workspace) = 0;
    //! deduce the output layout from the input layouts
    virtual void deduce_layout(const TensorLayout& src,
                               const TensorLayout& filter,
                               const TensorLayout& bias, TensorLayout& dst) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& dst) = 0;

protected:
    //! verify layouts and the workspace limit before execution
    virtual void check_layout(const TensorLayout& src,
                              const TensorLayout& filter,
                              const TensorLayout& bias, TensorLayout& dst,
                              size_t workspace_limit_in_bytes) = 0;
};
using ConvPooling = ConvPoolingForward;
//! base class for grouped local (locally-connected) convolution operators;
//! reuses the Convolution param
class GroupLocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GroupLocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce \p dst layout from \p src and \p filter
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! verify that the given layouts are mutually consistent
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class GroupLocalForward : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalForward, GroupLocalBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G)
     * \param[out] dst (N, OC, OH, OW)
     **/
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the output layout; thin wrapper over deduce_layout_fwd()
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst) {
        deduce_layout_fwd(src, filter, dst);
    }
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using GroupLocal = GroupLocalForward;
//! gradient of GroupLocalForward wrt. the input data
class GroupLocalBackwardData : public GroupLocalBase {
    DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1);

public:
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
  463. class GroupLocalBackwardFilter : public GroupLocalBase {
  464. DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1);
  465. public:
  466. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
  467. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  468. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  469. const TensorLayout& diff,
  470. const TensorLayout& grad) = 0;
  471. protected:
  472. void check_exec(const TensorLayout& filter, const TensorLayout& diff,
  473. const TensorLayout& grad, size_t workspace_in_bytes);
  474. };
  475. class Images2NeibsBase : public OperatorBase {
  476. DEF_OPR_IMPL_CTOR(Images2NeibsBase, OperatorBase);
  477. DEF_OPR_PARAM(Images2Neibs);
  478. protected:
  479. void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
  480. void check_layout_fwd(const TensorLayout& filter, const TensorLayout& dst);
  481. };
class Images2NeibsForward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsForward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] src (N, C, IH, IW)
     * \param[out] dst (N, C, OH, OW, window_h, window_w)
     *
     * \see
     * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
     *
     * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+ww}\f$,
     * where \f$ ih=-pad_h+oh*stride_h, iw=-pad_w+ow*stride_w\f$.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;
    //! deduce the output layout from the input layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Images2Neibs = Images2NeibsForward;
class Images2NeibsBackward : public Images2NeibsBase {
    DEF_OPR_IMPL(Images2NeibsBackward, Images2NeibsBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for Pooling
 */
class PoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(PoolingBase, OperatorBase);
    DEF_OPR_PARAM(Pooling);

public:
    using Mode = Param::Mode;

protected:
    //! deduce \p dst layout from \p src
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    //! verify that src/dst layouts are mutually consistent
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
class PoolingForward : public PoolingBase {
    DEF_OPR_IMPL(PoolingForward, PoolingBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! deduce the output layout from the input layout
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Pooling = PoolingForward;
class PoolingBackward : public PoolingBase {
    DEF_OPR_IMPL(PoolingBackward, PoolingBase, 3, 1);

public:
    /**
     * \param[in] src the `src' parameter in PoolingForward::exec
     * \param[in] dst the `dst' parameter in PoolingForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
/**
 * \brief base class for Local
 */
class LocalBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    using Mode = Param::Mode;

protected:
    //! deduce \p dst layout from \p src and \p filter
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! verify that the given layouts are mutually consistent
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalForward : public LocalBase {
    DEF_OPR_IMPL(LocalForward, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief Deducing output tensor layouts from input tensor layouts.
     *
     * Be aware that the first and second dimension of `filter' are ignored
     * when deducing `dst' layout.
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using Local = LocalForward;
//! gradient of LocalForward wrt. the input data
class LocalBackwardData : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardData, LocalBase, 2, 1);

public:
    /**
     * \param[in] filter (oh, ow, ic, fh, fw, oc)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! gradient of LocalForward wrt. the filter
class LocalBackwardFilter : public LocalBase {
    DEF_OPR_IMPL(LocalBackwardFilter, LocalBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, ih, iw)
     * \param[in] diff (n, oc, oh, ow)
     * \param[out] grad (oh, ow, ic, fh, fw, oc)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    //! minimum workspace size (in bytes) required by exec()
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and workspace size before exec()
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! common base of batch-normalization operators; holds the BN param
class BNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(BNBase, OperatorBase);
    DEF_OPR_PARAM(BN);

protected:
    //! sanity-check the operator param shared by forward and backward
    void check_param();
};
class BNForward : public BNBase {
    DEF_OPR_IMPL(BNForward, BNBase, 6, 5);

public:
    /**
     * \brief Batch normalization forward pass:
     * dst[i] = gamma
     * * (x[i] - estimatedMean[k]) / sqrt(epsilon + estimatedVariance[k]) + beta
     * \where epsilon is a very small value to avoid a "divide by zero" error.
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     * \param[out] mean (see m_param.ParamDim) Global mean.
     * \param[out] variance (see m_param.ParamDim) Global variance.
     * \param[out] batch_mean (see m_param.ParamDim)
     *      Optionally cached intermediate mean from forward pass
     * \param[out] batch_inv_variance (see m_param.ParamDim)
     *      Optionally cached intermediate variance from forward pass
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                      _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean,
                      _megdnn_tensor_inout variance,
                      _megdnn_tensor_out batch_mean,
                      _megdnn_tensor_out batch_inv_variance,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce all output layouts from the layout of src
    void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale,
                       TensorLayout& bn_bias, TensorLayout& mean,
                       TensorLayout& variance, TensorLayout& batch_mean,
                       TensorLayout& batch_inv_variance, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& bn_scale,
            const TensorLayout& bn_bias, const TensorLayout& mean,
            const TensorLayout& variance, const TensorLayout& batch_mean,
            const TensorLayout& batch_inv_variance,
            const TensorLayout& dst) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& bn_scale,
                    const TensorLayout& bn_bias, const TensorLayout& mean,
                    const TensorLayout& variance,
                    const TensorLayout& batch_mean,
                    const TensorLayout& batch_inv_variance,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
//! the canonical BN operator is the forward pass
using BN = BNForward;
class BNBackward : public BNBase {
    DEF_OPR_IMPL(BNBackward, BNBase, 5, 3);

public:
    /**
     * \brief Batch normalization backward pass.
     * \param[in] x input data of forward propagation.
     * \param[in] dy the backpropagated gradient of y.
     * \param[out] dx the backpropagated gradient of x.
     * \param[out] d_bn_scale, the backpropagated gradient of bn_scale.
     * \param[out] d_bn_bias, the backpropagated gradient of bn_bias.
     * Optionally cached intermediate results from forward pass:
     * \param[in] saved_batch_mean mean of the input batch.
     *      Calculated in the forward propagation.
     * \param[in] saved_batch_variance variance of the input batch.
     *      Calculated in the forward propagation.
     */
    virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                      _megdnn_tensor_in saved_batch_mean,
                      _megdnn_tensor_in saved_batch_variance,
                      _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
                      _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& dy,
            const TensorLayout& saved_batch_mean,
            const TensorLayout& saved_batch_variance,
            const TensorLayout& bn_scale, const TensorLayout& d_bn_scale,
            const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& x, const TensorLayout& dy,
                    const TensorLayout& saved_batch_mean,
                    const TensorLayout& saved_batch_variance,
                    const TensorLayout& bn_scale,
                    const TensorLayout& d_bn_scale,
                    const TensorLayout& d_bn_bias, const TensorLayout& dx,
                    size_t workspace_in_bytes);
};
//! common base of local-response-normalization operators; holds the LRN param
class LRNBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LRNBase, OperatorBase);
    DEF_OPR_PARAM(LRN);

protected:
    //! sanity-check the operator param shared by forward and backward
    void check_param();
};
class LRNForward : public LRNBase {
    DEF_OPR_IMPL(LRNForward, LRNBase, 1, 1);

public:
    /**
     * \brief Local response normalization.
     * \see ImageNet Classification with Deep Convolutional Neural Networks
     * \param[in] src (n, c, h, w)
     * \param[out] dst (n, c, h, w)
     *
     * src and dst must have the same shape.
     * src and dst must be contiguous.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    //! dst layout equals src layout (element-wise normalization)
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
//! the canonical LRN operator is the forward pass
using LRN = LRNForward;
class LRNBackward : public LRNBase {
    DEF_OPR_IMPL(LRNBackward, LRNBase, 3, 1);

public:
    /**
     * \brief Gradient of local response normalization.
     * \param[in] src the `src' parameter in LRNForward::exec
     * \param[in] dst the `dst' parameter in LRNForward::exec
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     *
     * All tensors should be contiguous and of the same shape.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
//! common base of ROI-pooling operators; holds the ROIPooling param
class ROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(ROIPooling);

protected:
    //! validate the forward-pass layout relations shared by fwd and bwd
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIPoolingForward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingForward, ROIPoolingBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * The internal implementation is akin to
     * https://github.com/rbgirshick/caffe-fast-rcnn .
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
//! the canonical ROIPooling operator is the forward pass
using ROIPooling = ROIPoolingForward;
class ROIPoolingBackward : public ROIPoolingBase {
    DEF_OPR_IMPL(ROIPoolingBackward, ROIPoolingBase, 4, 1);

public:
    /**
     * \brief Gradient of ROI pooling wrt. the input image.
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] src the `src' parameter in ROIPoolingForward::exec
     * \param[in] rois the `rois' parameter in ROIPoolingForward::exec
     * \param[in] index the `index' parameter in ROIPoolingForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src,
                      _megdnn_tensor_in rois, _megdnn_tensor_in index,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& diff, const TensorLayout& src,
                    const TensorLayout& rois, const TensorLayout& index,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! common base of 3d-convolution operators; holds the Convolution3D param
class Convolution3DBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(Convolution3DBase, OperatorBase);
    DEF_OPR_PARAM(Convolution3D);

public:
    //! maximal number of spatial dimensions for a 3d convolution
    static constexpr size_t MAX_SPATIAL_DIM = 3;
    using Mode = Param::Mode;

    //! filter description in a canonical, format-independent form
    struct CanonizedFilterMeta {
        DTypeEnum dtype_enum;
        Param::Format format;
        uint32_t
                //! whether filter should be flipped (i.e. is CONVOLUTION)
                should_flip,
                group,  //!< number of groups
                icpg,   //!< input channels per group
                ocpg,   //!< output channels per group
                spatial_ndim, stride[MAX_SPATIAL_DIM], padding[MAX_SPATIAL_DIM],
                //! spatial dim
                spatial[MAX_SPATIAL_DIM], dilation[MAX_SPATIAL_DIM],
                //! spatial dim with dilation applied
                dilated_spatial[MAX_SPATIAL_DIM];
    } MEGDNN_PACKED;

protected:
    //! deduce dst layout and return the canonized filter meta
    CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          TensorLayout& dst) const;
    //! validate forward layouts and return the canonized filter meta
    CanonizedFilterMeta check_layout_fwd(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         const TensorLayout& dst) const;
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter) const;
};
class Convolution3DForward
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DForward, 3> {
    DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[out] dst (n, oc, od, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layout of dst from src and filter layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! the canonical Convolution3D operator is the forward pass
using Convolution3D = Convolution3DForward;
//! compute the gradient of 3d convolution wrt. its input data
class Convolution3DBackwardData
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardData, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardData, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] filter (oc, ic, fd, fh, fw)
     * \param[in] diff (n, oc, od, oh, ow) backpropagated gradient wrt. dst
     * \param[out] grad (n, ic, id, ih, iw) backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce the layout of grad from filter and diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& filter,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
//! compute the gradient of 3d convolution wrt. its filter
class Convolution3DBackwardFilter
        : public Convolution3DBase,
          public detail::MultiAlgoOpr<Convolution3DBackwardFilter, 3> {
    DEF_OPR_IMPL(Convolution3DBackwardFilter, Convolution3DBase, 2, 1);

public:
    /**
     * \param[in] src (n, ic, id, ih, iw)
     * \param[in] diff (n, oc, od, oh, ow) backpropagated gradient wrt. dst
     * \param[out] grad (oc, ic, fd, fh, fw) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& diff,
                                   const TensorLayout& grad,
                                   size_t workspace_in_bytes);
};
//! common base of local-share convolution operators; holds the LocalShare param
class LocalShareBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(LocalShareBase, OperatorBase);
    DEF_OPR_PARAM(LocalShare);

protected:
    //! deduce dst layout from src and filter layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                           TensorLayout& dst);
    //! validate the forward-pass layout relations shared by fwd and bwd
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& dst);
};
class LocalShareForward : public LocalShareBase,
                          public detail::MultiAlgoOpr<LocalShareForward, 3> {
    DEF_OPR_IMPL(LocalShareForward, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[out] dst (N, OC, OH, OW)
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    /**
     * \brief deduce layout of the output tensor
     */
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& filter,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
//! the canonical LocalShare operator is the forward pass
using LocalShare = LocalShareForward;
//! compute the gradient of local-share convolution wrt. its input data
class LocalShareBackwardData
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardData, 3> {
    DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] filter (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G)
     * \param[in] diff (N, OC, OH, OW) backpropagated gradient wrt. dst
     * \param[out] grad (N, IC, IH, IW) backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& filter,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;
    //! deduce the layout of grad from filter and diff layouts
    void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
                       TensorLayout& grad);

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& filter, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! compute the gradient of local-share convolution wrt. its filter
class LocalShareBackwardFilter
        : public LocalShareBase,
          public detail::MultiAlgoOpr<LocalShareBackwardFilter, 3> {
    DEF_OPR_IMPL(LocalShareBackwardFilter, LocalShareBase, 2, 1);

public:
    /**
     * \param[in] src (N, IC, IH, IW)
     * \param[in] diff (N, OC, OH, OW) backpropagated gradient wrt. dst
     * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G,
     * FH, FW, OC / G) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& diff,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
//! common base of ROI-align operators; holds the ROIAlign param
class ROIAlignBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROIAlignBase, OperatorBase);
    DEF_OPR_PARAM(ROIAlign);

protected:
    //! deduce dst and index layouts from src and rois layouts
    void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                           TensorLayout& dst, TensorLayout& index);
    //! validate the forward-pass layout relations shared by fwd and bwd
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois,
                          const TensorLayout& dst, const TensorLayout& index);
};
class ROIAlignForward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignForward, ROIAlignBase, 2, 2);

public:
    /**
     * \param[in] src (n, c, ih, iw)
     * \param[in] rois (m, 5)
     * \param[out] dst (m, c, oh, ow)
     * \param[out] index (m, c, oh, ow) if mode is MAX, (0) if mode is AVERAGE
     *
     * Note that rois(, 0) denotes the input image index. We store it as
     * a float, but it should be an integer instead.
     *
     * index is a temporary tensor to facilitate its backward operator.
     * It is used to store argmax indices in MAX mode, and it is not used
     * in AVERAGE mode.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois,
                      _megdnn_tensor_out dst, _megdnn_tensor_out index,
                      _megdnn_workspace workspace) = 0;
    //! deduce dst and index layouts from src and rois layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& rois,
                       TensorLayout& dst, TensorLayout& index);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& rois,
                                          const TensorLayout& dst,
                                          const TensorLayout& index) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& src, const TensorLayout& rois,
                    const TensorLayout& dst, const TensorLayout& index,
                    size_t workspace_in_bytes);
};
//! the canonical ROIAlign operator is the forward pass
using ROIAlign = ROIAlignForward;
class ROIAlignBackward : public ROIAlignBase {
    DEF_OPR_IMPL(ROIAlignBackward, ROIAlignBase, 3, 1);

public:
    /**
     * \brief Gradient of ROI align wrt. the input image.
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] rois the `rois' parameter in ROIAlignForward::exec
     * \param[in] index the `index' parameter in ROIAlignForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois,
                      _megdnn_tensor_in index, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& rois,
                                          const TensorLayout& index,
                                          const TensorLayout& grad) = 0;

protected:
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& diff, const TensorLayout& rois,
                    const TensorLayout& index, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};
//! common base of deformable-convolution operators; reuses Convolution param
class DeformableConvBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformableConvBase, OperatorBase);
    DEF_OPR_PARAM(Convolution);

public:
    //! deformable convolution only supports 2 spatial dimensions
    static constexpr size_t MAX_SPATIAL_DIM = 2;

    //! convolution filter meta, extended with the deformable-group count
    struct CanonizedFilterMeta : Convolution::CanonizedFilterMeta {
        uint32_t deformable_group;
    };

protected:
    CanonizedFilterMeta make_canonized_filter_meta(
            size_t src_ndim, const TensorLayout& filter,
            const TensorLayout& offset) const;
    //! deduce dst layout from im/filter/mask/offset layouts
    void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter,
                           const TensorLayout& mask, const TensorLayout& offset,
                           TensorLayout& dst);
    //! validate the forward-pass layout relations shared by fwd and bwd
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter,
                          const TensorLayout& mask, const TensorLayout& offset,
                          const TensorLayout& dst);
};
class DeformableConvForward
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvForward, 5> {
    DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw)
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow) dg = deformable groups
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[out] dst (n, oc, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the layout of dst from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& filter,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! the canonical DeformableConv operator is the forward pass
using DeformableConv = DeformableConvForward;
/**
 * \brief DeformableConvBackwardFilter operator.
 *
 * Calculating the gradient wrt. convolution filter.
 */
class DeformableConvBackwardFilter
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardFilter, 5> {
    DEF_OPR_IMPL(DeformableConvBackwardFilter, DeformableConvBase, 4, 1);

public:
    /**
     * \param[in] im (n, ic, ih, iw) the `im' parameter of the forward op
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow) backpropagated gradient wrt. dst
     * \param[out] filter_grad (oc, ic, fh, fw) gradient wrt. the filter
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
                      _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
                      _megdnn_tensor_out filter_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& im,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& out_grad,
                                          const TensorLayout& filter_grad) = 0;
    //! deduce the layout of filter_grad from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& offset,
                       const TensorLayout& mask, const TensorLayout& out_grad,
                       TensorLayout& filter_grad);

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& im,
                                   const TensorLayout& offset,
                                   const TensorLayout& mask,
                                   const TensorLayout& out_grad,
                                   const TensorLayout& filter_grad,
                                   size_t workspace_in_bytes);
};
/**
 * \brief DeformableConvBackwardData operator.
 *
 * Calculating the gradient wrt. convolution input data, offset and mask.
 */
class DeformableConvBackwardData
        : public DeformableConvBase,
          public detail::MultiAlgoOpr<DeformableConvBackwardData, 8> {
    DEF_OPR_IMPL(DeformableConvBackwardData, DeformableConvBase, 5, 3);

public:
    /**
     * \param[in] im (n, ic, ih, iw) the `im' parameter of the forward op
     * \param[in] filter (oc, ic, fh, fw)
     * \param[in] offset (dg, 2, fh, fw, oh, ow)
     * \param[in] mask (dg, fh, fw, oh, ow)
     * \param[in] out_grad (n, oc, oh, ow) backpropagated gradient wrt. dst
     * \param[out] im_grad (n, ic, ih, iw)
     * \param[out] offset_grad (dg, 2, fh, fw, oh, ow)
     * \param[out] mask_grad (dg, fh, fw, oh, ow)
     */
    virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
                      _megdnn_tensor_in offset, _megdnn_tensor_in mask,
                      _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
                      _megdnn_tensor_out offset_grad,
                      _megdnn_tensor_out mask_grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0;
    //! deduce the three gradient layouts from the input layouts
    void deduce_layout(const TensorLayout& im, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& out_grad, TensorLayout& im_grad,
                       TensorLayout& offset_grad, TensorLayout& mask_grad);

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(
            const TensorLayout& im, const TensorLayout& filter,
            const TensorLayout& offset, const TensorLayout& mask,
            const TensorLayout& out_grad, const TensorLayout& im_grad,
            const TensorLayout& offset_grad, const TensorLayout& mask_grad,
            size_t workspace_in_bytes);
};
//! common base of deformable PS-ROI-pooling operators
class DeformablePSROIPoolingBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(DeformablePSROIPoolingBase, OperatorBase);
    DEF_OPR_PARAM(DeformablePSROIPooling);

protected:
    //! deduce out_data and out_count layouts from the input layouts
    void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                           const TensorLayout& rois, TensorLayout& out_data,
                           TensorLayout& out_count);
    //! validate the forward-pass layout relations shared by fwd and bwd
    void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans,
                          const TensorLayout& rois,
                          const TensorLayout& out_data,
                          const TensorLayout& out_count,
                          size_t workspace_in_bytes);
};
class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3,
                 2);

public:
    /**
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[out] out_data ( n, ic, ih, iw)
     * \param[out] out_count ( n, ic, ih, iw)
     */
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_data,
                                          const TensorLayout& out_count) = 0;
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_out out_data,
                      _megdnn_tensor_out out_count,
                      _megdnn_workspace workspace) = 0;
    //! deduce out_data and out_count layouts from the input layouts
    void deduce_layout(const TensorLayout& data, const TensorLayout& rois,
                       const TensorLayout& trans, TensorLayout& out_data,
                       TensorLayout& out_count);
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_data,
                    const TensorLayout& out_count, size_t workspace_in_bytes);
};
//! the canonical DeformablePSROIPooling operator is the forward pass
using DeformablePSROIPooling = DeformablePSROIPoolingForward;
class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase {
    DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5,
                 2);

public:
    /**
     * \brief Gradient of deformable PS-ROI pooling wrt. data and trans.
     * \param[in] data (oc, ic, ih, iw)
     * \param[in] rois (xx, xx, xx, xx)
     * \param[in] trans (oc, ic, fh, fw)
     * \param[in] out_diff (xx, xx, xx, xx) backpropagated gradient wrt. out_data
     * \param[in] out_count (xx, xx, xx, xx)
     * \param[out] data_diff ( n, ic, ih, iw)
     * \param[out] trans_diff ( n, ic, ih, iw)
     */
    virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois,
                      _megdnn_tensor_in trans, _megdnn_tensor_in out_diff,
                      _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff,
                      _megdnn_tensor_out trans_diff,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& rois,
                                          const TensorLayout& trans,
                                          const TensorLayout& out_diff,
                                          const TensorLayout& out_count,
                                          const TensorLayout& data_diff,
                                          const TensorLayout& trans_diff) = 0;
    //! validate layouts and ensure workspace_in_bytes is sufficient
    void check_exec(const TensorLayout& data, const TensorLayout& rois,
                    const TensorLayout& trans, const TensorLayout& out_diff,
                    const TensorLayout& out_count,
                    const TensorLayout& data_diff,
                    const TensorLayout& trans_diff, size_t workspace_in_bytes);
};
//! convolution with per-sample filters, fused with bias add and optional z add
class BatchConvBiasForward
        : public ConvolutionBase<param::BatchConvBias>,
          public detail::MultiAlgoOpr<BatchConvBiasForward, 5> {
    DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] filter per-batch convolution filter
     * \param[in] bias bias added to the convolution result
     * \param[in] z extra tensor fused into the output
     * \param[out] dst output tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                      _megdnn_tensor_in bias, _megdnn_tensor_in z,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    //! deduce the dtype of dst from the input dtypes
    void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
    //! deduce the layout of dst from the input layouts
    void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& bias, const TensorLayout& z,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& filter,
                                          const TensorLayout& bias,
                                          const TensorLayout& z,
                                          const TensorLayout& dst) = 0;

protected:
    //! validate layouts/workspace and return the canonized filter meta
    CanonizedFilterMeta check_exec(const TensorLayout& src,
                                   const TensorLayout& filter,
                                   const TensorLayout& bias,
                                   const TensorLayout& z,
                                   const TensorLayout& dst,
                                   size_t workspace_in_bytes);
};
//! the canonical BatchConvBias operator is the forward pass
using BatchConvBias = BatchConvBiasForward;
  1305. } // namespace megdnn
  1306. #include "megdnn/internal/opr_header_epilogue.h"
  1307. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台