
general.h

/**
 * \file dnn/include/megdnn/oprs/general.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"
#include "megdnn/thin/small_vector.h"

namespace megdnn {

/*!
 * \brief standard element-wise operator
 *
 * Inputs must have the same dtype, and their shapes must be broadcastable into
 * a final shape. They can have arbitrary layouts, but non-contiguous and
 * non-broadcast layouts may seriously harm performance.
 *
 * Output dtype is the same as input dtype (note that even for compare oprs
 * this is true, e.g. float == float returns a value of float). Output layout
 * must be contiguous.
 */
class ElemwiseForward : public OperatorBase {
    DEF_OPR_PARAM(Elemwise);
    DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1);

public:
    using Mode = Param::Mode;

    //! information about a mode
    struct ModeTrait {
        uint32_t arity;    //!< number of inputs needed
        bool commutable;   //!< whether arity == 2 and inputs commutable
        bool allow_int;    //!< whether int inputs allowed
        bool allow_float;  //!< whether float inputs allowed
        bool allow_bool;   //!< whether bool inputs allowed
        const char* name;  //!< name of the mode

        ModeTrait()
                : arity(0),
                  commutable(0),
                  allow_int(0),
                  allow_float(0),
                  allow_bool(0),
                  name(NULL) {}

        //! get trait from a mode; this function is thread safe
        static const ModeTrait& from_mode(Mode mode);
    };

    //! get trait of current mode
    const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); }

    /**
     * \param[in] src input tensors
     * \param[out] dst output tensor
     *
     * src and dst should have the same shape;
     * layouts should be contiguous;
     * the underlying data pointer can point to the same memory region for
     * src and dst.
     */
    virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;

    //! deduce output shape (do not check whether arity matches)
    static void deduce_shape(const TensorShapeArray& src, TensorShape& dst);

    static void deduce_format(const TensorFormatArray& src, TensorFormat& dst);

    //! deduce output layout
    void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);

protected:
    //! throw exception if incorrect layout; broadcast input shape to
    //! output shape
    void check_layout_and_broadcast(
            const TensorLayoutPtrArray& src, const TensorLayout& dst);

private:
    void check_dtype(DType dtype);
};
using Elemwise = ElemwiseForward;
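
/*!
 * Usage sketch (not part of the original header): a minimal ADD dispatch,
 * assuming the usual megdnn::Handle::create_operator factory and pre-built
 * TensorND values a, b and dst, with dst contiguous.
 * \code
 * auto opr = handle->create_operator<megdnn::Elemwise>();
 * opr->param().mode = megdnn::Elemwise::Mode::ADD;
 * TensorLayout dst_layout;
 * opr->deduce_layout({a.layout, b.layout}, dst_layout);  // broadcasts shapes
 * opr->exec({a, b}, dst);                                // dst = a + b
 * \endcode
 */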

/*!
 * \brief compute ``x**a`` where ``a`` is a constant from the Param
 *
 * This opr is usually not directly accessible by the end user and it is
 * created by the mgb optimizer, aiming to work around numerical stability
 * issues with pow. For example ``powf(x, 2.f)`` with ``x < 0`` in fast math
 * mode may return NaN.
 *
 * Like elemwise, this opr supports arbitrary strides. But it should only be
 * used with monotone strides. Input and output should have the same
 * float-category dtype.
 */
class PowC : public OperatorBase {
    DEF_OPR_PARAM(PowC);
    DEF_OPR_IMPL(PowC, OperatorBase, 1, 1);

public:
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst);

    //! compatible API for mgb; workspace is not used
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) {
        return exec(src, dst);
    }

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
        // the impls should require no workspace; this can be later changed to
        // a virtual function if this situation changes
        return 0;
    }

    void deduce_layout(const TensorLayout& src, TensorLayout& dst) {
        dst.dtype = src.dtype;
        dst.init_contiguous_stride(src);
    }

protected:
    /*!
     * Perform the computing where layouts have been verified.
     *
     * \p src can have arbitrary layout, and \p dst is contiguous. They have
     * the same shape and dtype.
     *
     * The implementation should not access param(). It should check \p exp_f
     * and \p exp_i for the exponent value. Exactly one of them would be
     * non-null.
     *
     * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any
     * kernel. They are allocated on the caller's stack.
     */
    virtual void do_exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f,
            const int* exp_i) = 0;
};
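
/*!
 * Implementation-note sketch (not from the original header): the exponent
 * pointers live on the caller's stack, so a backend must copy the value
 * before enqueuing an asynchronous kernel. A hypothetical backend might do:
 * \code
 * void MyPowCImpl::do_exec(
 *         _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f,
 *         const int* exp_i) {
 *     // copy by value now; the pointers dangle once the caller returns
 *     float exp = exp_i ? static_cast<float>(*exp_i) : *exp_f;
 *     launch_powc_kernel(src, dst, exp);  // hypothetical async dispatch
 * }
 * \endcode
 */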

/*!
 * \brief modify a tensor inplace by adding another tensor to it
 *
 * dst and delta can have arbitrary layout but must have the same shape.
 */
class AddUpdateForward : public OperatorBase {
    DEF_OPR_PARAM(AddUpdate);
    DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0;

protected:
    void check_exec(const TensorLayout& dst, const TensorLayout& delta);
};
using AddUpdate = AddUpdateForward;
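
/*!
 * Usage sketch (not from the original header): with default params this
 * reduces to an in-place accumulation; the exact scaling fields live in
 * param::AddUpdate and the default behavior assumed here is an assumption.
 * \code
 * auto opr = handle->create_operator<megdnn::AddUpdate>();
 * opr->exec(dst, delta);  // dst += delta under the default param
 * \endcode
 */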

class ReduceForward : public OperatorBase {
    DEF_OPR_PARAM(Reduce);
    DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1);

public:
    using Mode = Param::Mode;
    using DataType = Param::DataType;

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Reduce = ReduceForward;
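
/*!
 * Usage sketch (not from the original header) showing the deduce-layout /
 * query-workspace / exec protocol shared by most oprs in this file; the
 * workspace allocator is a hypothetical helper.
 * \code
 * auto opr = handle->create_operator<megdnn::Reduce>();
 * opr->param().mode = megdnn::Reduce::Mode::SUM;
 * opr->param().axis = 1;                       // reduce over axis 1
 * TensorLayout dst_layout;
 * opr->deduce_layout(src.layout, dst_layout);  // (m, n) -> (m, 1)
 * size_t ws = opr->get_workspace_in_bytes(src.layout, dst_layout);
 * opr->exec(src, dst, alloc_workspace(ws));    // hypothetical allocator
 * \endcode
 */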

class CorrelationBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase);
    DEF_OPR_PARAM(Correlation);

protected:
    void deduce_layout_fwd(
            const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& data1, const TensorLayout& data2,
            const TensorLayout& dst);
};

class CorrelationForward : public CorrelationBase {
    DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1);

public:
    /**
     * \param[in] data1 (n, c, ih, iw)
     * \param[in] data2 (n, c, ih, iw)
     * \param[out] dst (n, q, oh, ow), where q is the neighborhood size
     */
    virtual void exec(
            _megdnn_tensor_in data1, _megdnn_tensor_in data2, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data1, const TensorLayout& data2,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& data1, const TensorLayout& data2,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using Correlation = CorrelationForward;

class CorrelationBackwardData1 : public CorrelationBase {
    DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] data1 the `data1' parameter in CorrelationForward::exec
     * \param[in] data2 the `data2' parameter in CorrelationForward::exec
     * \param[out] grad1 the backpropagated gradient wrt. data1
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
            _megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& diff1, const TensorLayout& data1,
            const TensorLayout& data2, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& data1,
            const TensorLayout& data2, const TensorLayout& grad1) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& data1,
            const TensorLayout& data2, const TensorLayout& grad1,
            size_t workspace_in_bytes);
};

class CorrelationBackwardData2 : public CorrelationBase {
    DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] data1 the `data1' parameter in CorrelationForward::exec
     * \param[in] data2 the `data2' parameter in CorrelationForward::exec
     * \param[out] grad2 the backpropagated gradient wrt. data2
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2,
            _megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& diff1, const TensorLayout& data1,
            const TensorLayout& data2, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& data1,
            const TensorLayout& data2, const TensorLayout& grad2) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& data1,
            const TensorLayout& data2, const TensorLayout& grad2,
            size_t workspace_in_bytes);
};

class CumsumForward : public OperatorBase {
    DEF_OPR_PARAM(Cumsum);
    DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should have the same shape.
     *
     * The exclusive flag specifies whether the current element is taken
     * into account when calculating results.
     *
     * The reverse flag specifies whether cumsum is computed forward
     * (from 0 to n-1) or backward (from n-1 down to 0).
     *
     * Example:
     * exclusive && reverse:
     *     dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1}
     * exclusive && !reverse:
     *     dst_i = src_0 + src_1 + ... + src_{i-1}
     * !exclusive && reverse:
     *     dst_i = src_i + src_{i+1} + ... + src_{n-1}
     * !exclusive && !reverse:
     *     dst_i = src_0 + src_1 + ... + src_i
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Cumsum = CumsumForward;
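
/*!
 * Worked example (not from the original header): for src = {1, 2, 3}, the
 * four flag combinations above give
 * \code
 * //  exclusive &&  reverse: dst = {5, 3, 0}
 * //  exclusive && !reverse: dst = {0, 1, 3}
 * // !exclusive &&  reverse: dst = {6, 5, 3}
 * // !exclusive && !reverse: dst = {1, 3, 6}
 * \endcode
 */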

// mxx can be max or min
class ArgmxxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class ArgmaxForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmax indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Argmax = ArgmaxForward;

class ArgminForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmin indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Argmin = ArgminForward;

/*!
 * \brief take values from input according to given condition
 *
 * Output two tensors:
 * 1. values copied from *data*, with same dtype as *data*
 * 2. selected indices with dtype int32; note that it is 1-dimensional and
 *    based on the flattened input.
 *
 * Require data and mask to have the same shape and both be contiguous.
 */
class CondTake : public OperatorBase {
    DEF_OPR_IMPL(CondTake, OperatorBase, 2, 2);
    DEF_OPR_PARAM(CondTake);

public:
    using Output = std::array<TensorND, 2>;
    using OutputDType = std::array<DType, 2>;

    OutputDType infer_dtype(DType data, DType mask);
    virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
    virtual Output exec(
            _megdnn_tensor_in data, _megdnn_tensor_in mask,
            _megdnn_workspace workspace, DynOutMallocPolicyCall malloc_policy) = 0;

protected:
    //! check input layouts and get flattened size
    size_t check_exec_get_size(
            const TensorLayout& data, const TensorLayout& mask,
            size_t workspace_in_bytes);
};
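
/*!
 * Worked example (not from the original header): with data = {10, 20, 30, 40}
 * and mask = {1, 0, 1, 0}, a condition selecting mask == 1 (the mode/val
 * fields of param::CondTake assumed here) yields
 * \code
 * // output[0] (values):  {10, 30}
 * // output[1] (indices): {0, 2}   // int32, into the flattened data
 * \endcode
 */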

class TransposeForward : public OperatorBase {
    DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] src (m, n) stride[0] >= n && stride[1] == 1
     * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Transpose = TransposeForward;

/**
 * Change a tensor to another layout that has the same dtype and total number
 * of elements, and non-overlapping stride.
 *
 * ON CPU:
 * This operator is optimized for some cases (e.g. both dst and the last dim
 * of src are contiguous).
 *
 * ON CUDA:
 * The more contiguous the input/output layouts, the higher the performance.
 * There is also a special optimization for the broadcast case.
 */
class RelayoutForward : public OperatorBase {
    DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /*!
     * \brief execute relayout opr
     *
     * This operator should be placed on the same computing device as *dst*.
     *
     * \param src_handle handle of the input tensor; for CUDA device-to-device
     *      copy, the src handle can be on a different GPU when copying a
     *      tensor with no more than two non-contiguous dims
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            Handle* src_handle = nullptr) = 0;

protected:
    //! check layout and collapse contiguous
    void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst);
};
using Relayout = RelayoutForward;
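
/*!
 * Usage sketch (not from the original header): Relayout can materialize a
 * transpose by copying a stride-swapped view into a contiguous tensor; the
 * TensorND/TensorLayout member names used here are assumptions.
 * \code
 * auto opr = handle->create_operator<megdnn::Relayout>();
 * TensorLayout view = src.layout;               // src is contiguous (m, n)
 * std::swap(view.shape[0], view.shape[1]);
 * std::swap(view.stride[0], view.stride[1]);    // now a (n, m) transposed view
 * opr->exec(TensorND{src.raw_ptr, view}, dst);  // dst contiguous; dst[i][j] = src[j][i]
 * \endcode
 */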

/**
 * \brief Base class for Concat and Split operators
 */
class ConcatSplitBase : public OperatorBase {
public:
    using Param = param::Axis;

    ConcatSplitBase(Handle* handle);
    const Param& param() const { return m_param; }
    Param& param() { return m_param; }

protected:
    void check_layout_common(const TensorLayoutArray& srcs, const TensorLayout& dst);
    Param m_param;

    /**
     * \brief a helper function
     *
     * A = shape[0] * shape[1] * ... * shape[axis-1]
     * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...}
     * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1]
     */
    void get_ABC(const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C);

    thin_function<TensorLayout(const TensorND& tensor)> m_get_layout;
    thin_function<TensorShape(const TensorLayout& layout)> m_get_shape;
};

class ConcatForward : public ConcatSplitBase {
    DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] srcs a vector containing all inputs to be concatenated
     * \param[out] dst the output tensor
     *
     * All tensors in srcs and dst should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for dst should be the
     * sum of corresponding axis shapes for all srcs.
     */
    virtual void exec(
            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayoutArray& srcs, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayoutArray& srcs, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Concat = ConcatForward;

class SplitForward : public ConcatSplitBase {
    DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dsts a vector containing all split results
     *
     * All tensors in src and dsts should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for src should be the
     * sum of corresponding axis shapes for all dsts.
     */
    virtual void exec(
            _megdnn_tensor_in src, const TensorNDArray& dsts,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayoutArray& dsts) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayoutArray& dsts,
            size_t workspace_in_bytes);
};
using Split = SplitForward;

/**
 * \brief Base class for the ParamPackConcat and ParamPackSplit operators.
 *
 * ParamPack oprs act like Concat and Split, but they are also optimized for a
 * large number of inputs and can handle alignment requirements. The axis
 * parameter is not supported.
 *
 * The offsets can be generated by gen_offsets().
 */
class ParamPackConcatSplitBase : public OperatorBase {
protected:
    void check_exec(
            const TensorLayout& concated, const TensorLayout& offsets,
            const TensorLayout& parts);

public:
    using Param = megdnn::param::Empty;

    ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}

    //! generate offsets to be used with ParamPackConcat and ParamPackSplit
    static std::vector<dt_int32> gen_offsets(
            const TensorShapeArray& shapes, size_t alignment, size_t dtype_size);
};
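
/*!
 * Usage sketch (not from the original header): generating offsets for three
 * fp32 parts with 64-byte alignment; the begin/end pairing follows the
 * ParamPackConcat::exec documentation below.
 * \code
 * megdnn::TensorShapeArray shapes = {{100}, {3, 3}, {17}};
 * auto offsets = megdnn::ParamPackConcatSplitBase::gen_offsets(
 *         shapes, 64, sizeof(float));
 * // offsets[2 * i] / offsets[2 * i + 1] give the begin/end of the i-th
 * // part inside the packed tensor, with each begin suitably aligned
 * \endcode
 */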

/**
 * \brief ParamPackConcat, used for calculating the gradient of ParamPackSplit.
 *
 * Combines multiple gradient tensors into a single large tensor; a copy
 * strategy is used because of AddUpdate and other dynamic situations.
 */
class ParamPackConcat : public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1);

public:
    /*
     * \param[in] srcs TensorND on cpu; srcs[i] corresponds to the
     *      address of the i-th tensor
     * \param[in] offsets with size `2 * srcs.shape[0]`;
     *      offsets[i * 2] and offsets[i * 2 + 1] are the begin and the
     *      end of srcs[i]'s offsets in dst
     * \param[out] dst output TensorND, which lives on cpu or gpu
     */
    virtual void exec(
            _megdnn_tensor_in srcs, _megdnn_tensor_in offsets, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorShapeArray& srcs, const TensorShape& offsets,
            const TensorShape& dst) = 0;
};

/**
 * \brief base class for Tile and Repeat
 */
class TileRepeatBase : public OperatorBase {
public:
    TileRepeatBase(Handle* handle) : OperatorBase(handle) {}

    struct Param {
        TensorShape times;
    };
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    /**
     * Assuming src/dst/times are already simplified on entry.
     */
    size_t get_workspace_in_bytes_fwd(
            const TensorShape& src, const TensorShape& dst, const TensorShape& times,
            DType dtype);
    Param m_param;
};

class TileBase : public TileRepeatBase {
public:
    TileBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(
            const TensorShape& src, const TensorShape& dst, const TensorShape& times,
            TensorShape& src2, TensorShape& dst2, TensorShape& times2);
    /**
     * This is a helper function that facilitates other backends'
     * implementations.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class TileForward : public TileBase {
    DEF_OPR_IMPL(TileForward, TileBase, 1, 1);

public:
    /**
     * \brief Tile src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     *
     * Difference between Tile and Repeat:
     * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice
     * yields `aabbcc'.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Tile = TileForward;
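
/*!
 * Worked example (not from the original header): tiling a (2, 3) tensor with
 * times = (2, 2) repeats the whole block, whereas repeating it with the same
 * times duplicates each element in place; both produce shape (4, 6):
 * \code
 * // src = {{a, b, c},   Tile:   {{a, b, c, a, b, c},   Repeat: {{a, a, b, b, c, c},
 * //        {d, e, f}}            {d, e, f, d, e, f},            {a, a, b, b, c, c},
 * //                              {a, b, c, a, b, c},            {d, d, e, e, f, f},
 * //                              {d, e, f, d, e, f}}            {d, d, e, e, f, f}}
 * \endcode
 */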

class TileBackward : public TileBase {
    DEF_OPR_IMPL(TileBackward, TileBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class RepeatBase : public TileRepeatBase {
public:
    RepeatBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(
            const TensorShape& src, const TensorShape& dst, const TensorShape& times,
            TensorShape& src2, TensorShape& dst2, TensorShape& times2);
    /**
     * This is a helper function that facilitates other backends'
     * implementations.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class RepeatForward : public RepeatBase {
    DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1);

public:
    /**
     * \brief Repeat src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
     * \see TileForward
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Repeat = RepeatForward;

class RepeatBackward : public RepeatBase {
    DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class ArgsortForward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(Argsort);

public:
    using Order = Param::Order;

    /**
     * \param[in] src (m, n)
     * \param[out] dst (m, n)
     * \param[out] indices (m, n)
     *
     * src, dst and indices should be contiguous.
     * Performs m independent sorts on m arrays of length n, storing the
     * sorted arrays in `dst' and the corresponding indices in `indices'.
     *
     * Indices range from 0 to n-1.
     *
     * Note that indices is a TensorND of type int.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& indices) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& indices, size_t workspace_in_bytes);
};
using Argsort = ArgsortForward;
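
/*!
 * Worked example (not from the original header): for a single row
 * src = {3.0, 1.0, 2.0} sorted in ascending order,
 * \code
 * // dst     = {1.0, 2.0, 3.0}
 * // indices = {1, 2, 0}        // src[indices[j]] == dst[j]
 * \endcode
 */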

/*!
 * \brief backward opr for Argsort
 *
 * Note: the name is kept for backward compatibility. This opr is actually a
 * batched value setter. It is used for gradient computing of Argsort and TopK.
 */
class ArgsortBackward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] diff (m, k) the backpropagated gradient wrt. dst
     * \param[in] indices (m, k) the `indices' parameter in
     *      ArgsortForward::exec
     * \param[out] grad (m, n) the backpropagated gradient wrt. src
     *
     * Constraint: n >= k. Untouched values would be initialized as zero.
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& indices,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& indices,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class TopK : public OperatorBase {
    DEF_OPR_IMPL(TopK, OperatorBase, 1, 2);
    DEF_OPR_PARAM(TopK);

protected:
    //! impl exec; inputs have been validated
    virtual void do_exec(
            int k, _megdnn_tensor_in data, _megdnn_tensor_out values, int32_t* indices,
            _megdnn_workspace workspace) = 0;

public:
    /*!
     * \param[in] k if positive, compute the smallest top-k values; otherwise
     *      compute the largest top-k values
     * \param[in] data (m, n) input data, where top-k is computed on the
     *      second axis. The second dimension must be contiguous, and the
     *      first dimension can have arbitrary stride.
     * \param[out] values (m, ) or (m, k) output values; its shape depends
     *      on mode
     * \param[out] indices () or (m, ) or (m, k) output indices; its shape
     *      depends on mode
     */
    void exec(
            int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
            _megdnn_tensor_out indices, _megdnn_workspace workspace);
    virtual size_t get_workspace_in_bytes(
            int k, const TensorLayout& data, const TensorLayout& values,
            const TensorLayout& indices) = 0;
    void deduce_layout(
            int k, const TensorLayout& data, TensorLayout& values,
            TensorLayout& indices);
};
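
/*!
 * Usage sketch (not from the original header): querying the five largest
 * values per row; the mode name below is an assumption about param::TopK.
 * \code
 * auto opr = handle->create_operator<megdnn::TopK>();
 * opr->param().mode = megdnn::TopK::Param::Mode::VALUE_IDX_SORTED;  // assumed
 * TensorLayout values, indices;
 * opr->deduce_layout(-5, data.layout, values, indices);  // k < 0: largest
 * \endcode
 */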

/*!
 * \brief convert dtype of *src* to match dtype of *dst*; *src* may have
 * arbitrary layout and *dst* must be contiguous.
 */
class TypeCvtForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst);
};
using TypeCvt = TypeCvtForward;

class IndexingRemapBase : public OperatorBase {
public:
    using Param = param::IndexingRemap;

    IndexingRemapBase(Handle* handle) : OperatorBase(handle) {}
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst);
};

class IndexingRemapForward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] map input map
     * \param[out] dst output tensor
     *
     * Suppose:
     * the shape of src is \f$(s_0, s_1, ..., s_{m-1})\f$;
     * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$;
     * then:
     * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$.
     *
     * The last dimension of map indicates the src indices for the
     * corresponding dst entry.
     *
     * src and dst can be non-contiguous in a non-overlapping manner.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& map, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using IndexingRemap = IndexingRemapForward;
// The using directives preserve backward compatibility.
using TensorRemapForward = IndexingRemap;
using TensorRemap = TensorRemapForward;
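
/*!
 * Worked example (not from the original header): with a 1-d src of shape
 * (4,) and a desired dst of shape (2,), map has shape (2, 1):
 * \code
 * // src = {10, 20, 30, 40}
 * // map = {{3}, {0}}          // each row holds one src subscript
 * // dst = {40, 10}            // dst[i] = src[map[i][0]]
 * \endcode
 */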

class IndexingRemapBackward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] map the `map' parameter in IndexingRemapForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in map, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& map,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& map, const TensorLayout& grad,
            size_t workspace_in_bytes);
};
// The using directive preserves backward compatibility.
using TensorRemapBackward = IndexingRemapBackward;

class Linspace : public OperatorBase {
    DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1);
    DEF_OPR_PARAM(LinspaceFull);

public:
    /**
     * \param[out] dst must be 1d.
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html
     */
    virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class Eye : public OperatorBase {
    DEF_OPR_IMPL(Eye, OperatorBase, 0, 1);
    DEF_OPR_PARAM(Eye);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html
     */
    virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class IndexingOneHotBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& index,
            const TensorLayout& dst);
};

/*!
 * \brief Indexing for one-hot encoding
 *
 * Given src, axis and index,
 * for all valid (n-1)-dimensional subscript tuples i iterating through index:
 * dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] =
 * src[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]
 *
 * \param[in] src n-dimensional input data
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[out] dst n-dimensional output data
 */
class IndexingOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1);

public:
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& index, TensorLayout& dst) {
        deduce_layout_fwd(src, index, dst);
    }
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in index, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& index,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& index, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using IndexingOneHot = IndexingOneHotForward;
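
/*!
 * Worked example (not from the original header): picking one class score per
 * row, i.e. axis = 1 with src of shape (2, 3) and index of shape (2,):
 * \code
 * // src   = {{0.1, 0.2, 0.7},
 * //          {0.5, 0.4, 0.1}}
 * // index = {2, 0}
 * // dst   = {{0.7},            // src[0][index[0]]
 * //          {0.5}}            // src[1][index[1]]; dst shape is (2, 1)
 * \endcode
 */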

/*!
 * \brief set-subtensor corresponding to IndexingOneHotForward
 *
 * \param[in,out] data n-dimensional input and output data, whose sub part
 *      corresponding to *index* would be replaced by *sub*
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[in] sub n-dimensional sub tensor to be filled in *data*
 */
class IndexingSetOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in index, _megdnn_tensor_in sub,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& index,
            const TensorLayout& sub) = 0;

protected:
    void check_exec(
            const TensorLayout& data, const TensorLayout& index,
            const TensorLayout& sub, size_t workspace_in_bytes);
};
using IndexingSetOneHot = IndexingSetOneHotForward;

/*!
 * \brief base class for indexing on multiple axes using vector indices
 *
 * Note that the indexing axes are required to be sorted in ascending order.
 */
class IndexingMultiAxisVecBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase);
    DEF_OPR_PARAM(Empty);

public:
    struct AxisIndexer {
        size_t axis;
        TensorND vec;
    };

    struct AxisIndexerLayoutOnly {
        size_t axis;
        TensorLayout layout;
    };

    using IndexDesc = std::vector<AxisIndexer>;
    using IndexDescLayoutOnly = std::vector<AxisIndexerLayoutOnly>;

    /*!
     * \brief convert IndexDesc to IndexDescLayoutOnly
     */
    static IndexDescLayoutOnly extract_index_layout(const IndexDesc& index);

    /*!
     * \brief get the axes on src that are not used in index
     * \param[out] out output buffer; suggested size is
     *      TensorLayout::MAX_NDIM
     * \return number of elements written to *out*
     */
    static size_t get_nonindex_axes(
            size_t src_ndim, const IndexDesc& index, size_t* out);

    /*!
     * \brief get contiguous-collapsed layout for indexing on value
     * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis)
     * \return a tensor layout and an axis to iterate over *value* and also
     *      access *data*; stride of layout on that axis would be zero, and
     *      strides on other axes correspond to the strides in *data*
     */
    static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
            const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
            size_t idx_axis);

    //! helper info for kernel implementation
    struct ExecInfo {
        //! axis in value used by indexer
        size_t idx_axis;
        ptrdiff_t value_stride;
        void* error_tracker;
        megcore::AsyncErrorInfo* error_info;
    };

protected:
    /*!
     * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis)
     */
    static size_t deduce_layout_fwd(
            const TensorLayout& data, const IndexDescLayoutOnly& index,
            TensorLayout& dst);

    static ExecInfo check_exec_noworkspace(
            const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
            IndexDescLayoutOnly& index_layout);
};

/*!
 * \brief compute indexing result, like numpy advanced indexing
 *
 * src can have arbitrary layout, but dst must be dim1-contig.
 */
class IndexingMultiAxisVec : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on output shape and indexing axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape& dst, const size_t* axes, size_t nr_axes);

    static void deduce_layout(
            const TensorLayout& data, const IndexDescLayoutOnly& index,
            TensorLayout& dst) {
        deduce_layout_fwd(data, index, dst);
    }

protected:
    virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0;

    ExecInfo check_exec(
            const TensorLayout& src, const IndexDesc& index, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
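
/*!
 * Usage sketch (not from the original header): the numpy expression
 * dst = src[:, idx] maps to a single AxisIndexer on axis 1; idx is assumed
 * to be an int32 TensorND of shape (k,).
 * \code
 * auto opr = handle->create_operator<megdnn::IndexingMultiAxisVec>();
 * megdnn::IndexingMultiAxisVec::IndexDesc desc{{1, idx}};
 * TensorLayout dst_layout;
 * megdnn::IndexingMultiAxisVec::deduce_layout(
 *         src.layout,
 *         megdnn::IndexingMultiAxisVec::extract_index_layout(desc),
 *         dst_layout);  // (m, n) indexed with (k,) on axis 1 -> (m, k)
 * \endcode
 */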

/*!
 * \brief base class for modifying data by given index
 *
 * data can have arbitrary layout, but value must be dim1-contig.
 */
class IndexingModifyMultiAxisVecBase : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase, IndexingMultiAxisVecBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on shape of value input and indexing
     *      axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape& value, const size_t* axes, size_t nr_axes);

protected:
    ExecInfo check_exec(
            const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
            size_t workspace_in_bytes);

    virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0;
};

//! set value to indexed locations; index values must be non-overlapping
class IndexingSetMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingSetMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0);
};

//! add value to indexed locations; index values must be non-overlapping
class IndexingIncrMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingIncrMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0);
};

class MeshBase : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL_CTOR(MeshBase, OperatorBase);

public:
    using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer;
    using IndexDesc = IndexingMultiAxisVecBase::IndexDesc;
    using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
    using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;

    size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) {
        return 0;
    }

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc);
};

class NormalMeshBase : public MeshBase {
    DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0);

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class NormalMeshModifyBase : public NormalMeshBase {
    DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc,
            _megdnn_workspace workspace) = 0;
};

class BatchedMeshBase : public MeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase);

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class BatchedMeshModifyBase : public BatchedMeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc,
            _megdnn_workspace workspace) = 0;
};

class MeshIndexing : public NormalMeshBase {
    DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
            TensorLayout& out_layout);
};

class IncrMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(IncrMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class SetMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(SetMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class BatchedMeshIndexing : public BatchedMeshBase {
    DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
            TensorLayout& out_layout);
};

class BatchedIncrMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedIncrMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class BatchedSetMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedSetMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class RelayoutFormat : public OperatorBase {
    DEF_OPR_PARAM(RelayoutFormat);
    DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    void deduce_format(TensorFormat src, TensorFormat& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
    void deduce_exec_layout(
            const TensorLayout& src, const TensorLayout& dst,
            TensorLayout& exec_workspace, TensorLayout& exec_src,
            TensorLayout& exec_dst);
};

/*!
 * \brief check whether the input contains inf or nan values.
 */
class CheckNonFinite : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);

public:
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};

/*!
 * \brief fill the tensor with a scalar value
 */
class Fill : public OperatorBase {
    DEF_OPR_PARAM(Fill);
    DEF_OPR_IMPL(Fill, OperatorBase, 0, 1);

public:
    virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

/*!
 * \brief standard padding operator
 *
 * Input and output must have the same dtype, and the output tensor shape must
 * be greater than or equal to the input tensor shape in every dimension; the
 * extra space is filled with a constant value, which defaults to 0.
 */
class PaddingBase : public OperatorBase {
    DEF_OPR_PARAM(Padding);
    DEF_OPR_IMPL(PaddingBase, OperatorBase, 1, 1);

public:
    using Mode = Param::PaddingMode;

protected:
    SmallVector<size_t> get_offsets();
    void check_exec(const TensorLayout& src, const TensorLayout& dst);
};

class PaddingForward : public PaddingBase {
    DEF_OPR_IMPL(PaddingForward, PaddingBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) {
        return exec(src, dst);
    }
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);

protected:
    void forward_check_exec(const TensorLayout& src, const TensorLayout& dst);
};
using Padding = PaddingForward;

class PaddingBackward : public PaddingBase {
    DEF_OPR_IMPL(PaddingBackward, PaddingBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) {
        return exec(src, dst);
    }
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void backward_check_exec(const TensorLayout& src, const TensorLayout& dst);
};

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen
