

/**
 * \file dnn/include/megdnn/oprs/general.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/internal/opr_header_prologue.h"
#include "megdnn/thin/small_vector.h"

namespace megdnn {

/*!
 * \brief standard element-wise operator
 *
 * Inputs must have the same dtype, and their shapes must be broadcastable
 * into a final shape. They can have arbitrary layouts, but non-contiguous and
 * non-broadcast layouts may harm performance seriously.
 *
 * Output dtype is the same as input dtype (note that even for compare oprs
 * this is true, e.g. float == float returns a value of float). Output layout
 * must be contiguous.
 */
class ElemwiseForward: public OperatorBase {
    DEF_OPR_PARAM(Elemwise);
    DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1);

public:
    using Mode = Param::Mode;

    //! information about a mode
    struct ModeTrait {
        uint32_t arity;    //!< number of inputs needed
        bool commutable;   //!< whether arity == 2 and inputs commutable
        bool allow_int;    //!< whether int inputs allowed
        bool allow_float;  //!< whether float inputs allowed
        const char* name;  //!< name of the mode

        ModeTrait():
            arity(0), commutable(0), allow_int(0), allow_float(0),
            name(NULL)
        {}

        //! get trait from a mode; this function is thread safe
        static const ModeTrait& from_mode(Mode mode);
    };

    //! get trait of current mode
    const ModeTrait& mode_trait() const {
        return ModeTrait::from_mode(m_param.mode);
    }

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should have the same shape;
     * layouts should be contiguous;
     * the underlying data pointer can point to the same memory region for
     * src and dst.
     */
    virtual void exec(_megdnn_in const TensorNDArray &src,
            _megdnn_tensor_out dst) = 0;

    //! deduce output shape (do not check whether arity matches)
    static void deduce_shape(
            const TensorShapeArray &src,
            TensorShape &dst);

    static void deduce_format(const TensorFormatArray& src,
                              TensorFormat& dst);

    //! deduce output layout
    void deduce_layout(const TensorLayoutArray &src,
            TensorLayout &dst);

protected:
    //! throw exception if incorrect layout; broadcast input shape to
    //! output shape
    void check_layout_and_broadcast(
            const TensorLayoutPtrArray &src, const TensorLayout &dst);

private:
    void check_dtype(DType dtype);
};
using Elemwise = ElemwiseForward;
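
/*
 * Illustrative sketch (not part of this header): how two shapes broadcast
 * into the final Elemwise output shape, in the spirit of deduce_shape()
 * above. `broadcast_shape_ref` is a hypothetical reference helper, not
 * MegDNN API.
 */
#include <cstddef>
#include <stdexcept>
#include <vector>

static std::vector<size_t> broadcast_shape_ref(std::vector<size_t> a,
                                               std::vector<size_t> b) {
    // align the shorter shape by prepending size-1 dims, then take the
    // maximum on each axis; a mismatch where neither side is 1 is an error
    if (a.size() < b.size())
        a.insert(a.begin(), b.size() - a.size(), 1);
    if (b.size() < a.size())
        b.insert(b.begin(), a.size() - b.size(), 1);
    std::vector<size_t> out(a.size());
    for (size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
            throw std::runtime_error("shapes not broadcastable");
        out[i] = a[i] > b[i] ? a[i] : b[i];
    }
    return out;  // e.g. (3, 1, 5) and (4, 5) -> (3, 4, 5)
}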

/*!
 * \brief compute ``x**a`` where ``a`` is a constant from the Param
 *
 * This opr is usually not directly accessible by the end user and it is
 * created by the mgb optimizer, aiming to work around numerical stability
 * issues with pow. For example ``powf(x, 2.f)`` with ``x < 0`` in fast math
 * mode may return NaN.
 *
 * Like elemwise, this opr supports arbitrary strides. But it should only be
 * used with monotone strides. Input and output should have the same
 * float-category dtype.
 */
class PowC : public OperatorBase {
    DEF_OPR_PARAM(PowC);
    DEF_OPR_IMPL(PowC, OperatorBase, 1, 1);

public:
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst);

    //! compatible API for mgb; workspace is not used
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_workspace) {
        return exec(src, dst);
    }

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
        // the impls should require no workspace; this can be later changed
        // to a virtual function if this situation changes
        return 0;
    }

    void deduce_layout(const TensorLayout& src, TensorLayout& dst) {
        dst.dtype = src.dtype;
        dst.init_contiguous_stride(src);
    }

protected:
    /*!
     * Perform the computing where layouts have been verified.
     *
     * \p src can have arbitrary layout, and \p dst is contiguous. They have
     * the same shape and dtype.
     *
     * The implementation should not access param(). It should check \p exp_f
     * and \p exp_i for the exponent value. Exactly one of them would be
     * non-null.
     *
     * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any
     * kernel. They are allocated on the caller's stack.
     */
    virtual void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                         const float* exp_f, const int* exp_i) = 0;
};
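
/*
 * Illustrative sketch (not part of this header): the kind of numerical issue
 * PowC works around. Under fast-math, powf(x, 2.f) may be lowered to
 * something like expf(2.f * logf(x)), which yields NaN for x < 0, while a
 * dedicated square is well defined. `pow2_ref` and `pow2_fastmath_like` are
 * hypothetical reference helpers.
 */
#include <cmath>

static float pow2_ref(float x) {
    return x * x;  // always well defined, even for x < 0
}

static float pow2_fastmath_like(float x) {
    return std::exp(2.f * std::log(x));  // NaN for x < 0: log of a negative
}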

/*!
 * \brief modify a tensor inplace by adding another tensor to it
 *
 * dst and delta can have arbitrary layout but must have the same shape.
 */
class AddUpdateForward: public OperatorBase {
    DEF_OPR_PARAM(AddUpdate);
    DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1);

public:
    virtual void exec(
            _megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0;

protected:
    void check_exec(const TensorLayout &dst, const TensorLayout &delta);
};
using AddUpdate = AddUpdateForward;

class ReduceForward: public OperatorBase {
    DEF_OPR_PARAM(Reduce);
    DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1);

public:
    using Mode = Param::Mode;
    using DataType = Param::DataType;

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Reduce = ReduceForward;
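
/*
 * Illustrative sketch (not part of this header): the shape rule documented
 * for Reduce -- dst equals src except that the param().axis dimension
 * becomes one. `reduce_sum_ref` is a hypothetical reference for a sum
 * reduction over axis 1 of an (m, n) array.
 */
#include <cstddef>
#include <vector>

static std::vector<float> reduce_sum_ref(const std::vector<float>& src,
                                         size_t m, size_t n) {
    std::vector<float> dst(m, 0.f);  // logical shape (m, 1)
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            dst[i] += src[i * n + j];
    return dst;
}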

class CumsumForward: public OperatorBase {
    DEF_OPR_PARAM(Cumsum);
    DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should have the same shape.
     *
     * The exclusive flag specifies whether the current element is taken
     * into account when calculating results.
     *
     * The reverse flag specifies whether cumsum is forward (
     * from 0 to n) or backward (from n down to 0).
     *
     * Example:
     * exclusive && reverse:
     *      dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1}
     * exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_{i-1}
     * !exclusive && reverse:
     *      dst_i = src_i + src_{i+1} + ... + src_{n-1}
     * !exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_i
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Cumsum = CumsumForward;
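
/*
 * Illustrative sketch (not part of this header): a direct 1-D reference for
 * the four exclusive/reverse combinations listed in the Cumsum doc above.
 * `cumsum_ref` is a hypothetical helper, not MegDNN API.
 */
#include <cstddef>
#include <vector>

static std::vector<float> cumsum_ref(const std::vector<float>& src,
                                     bool exclusive, bool reverse) {
    const size_t n = src.size();
    std::vector<float> dst(n, 0.f);
    for (size_t i = 0; i < n; ++i) {
        // range of source elements contributing to dst[i]
        size_t begin = reverse ? i + (exclusive ? 1 : 0) : 0;
        size_t end = reverse ? n : i + (exclusive ? 0 : 1);
        for (size_t j = begin; j < end; ++j)
            dst[i] += src[j];
    }
    return dst;
}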

// mxx can be max or min
class ArgmxxBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void check_layout_fwd(const TensorLayout &src,
            const TensorLayout &dst);
};

class ArgmaxForward: public ArgmxxBase {
    DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmax indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Argmax = ArgmaxForward;

class ArgminForward: public ArgmxxBase {
    DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmin indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Argmin = ArgminForward;

/*!
 * \brief take values from input according to given condition
 *
 * Outputs two tensors:
 * 1. values copied from *data*, with the same dtype as *data*
 * 2. selected indices with dtype int32; note that it is 1-dimensional and
 *    based on the flattened input.
 *
 * Requires data and mask to have the same shape and both be contiguous.
 */
class CondTake : public OperatorBase {
    DEF_OPR_IMPL(CondTake, OperatorBase, 2, 2);
    DEF_OPR_PARAM(CondTake);

public:
    using Output = std::array<TensorND, 2>;
    using OutputDType = std::array<DType, 2>;
    OutputDType infer_dtype(DType data, DType mask);

    virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;

    virtual Output exec(_megdnn_tensor_in data, _megdnn_tensor_in mask,
                        _megdnn_workspace workspace,
                        DynOutMallocPolicyCall malloc_policy) = 0;

protected:
    //! check input layouts and get flattened size
    size_t check_exec_get_size(const TensorLayout& data,
                               const TensorLayout& mask,
                               size_t workspace_in_bytes);
};
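
/*
 * Illustrative sketch (not part of this header): CondTake's two outputs on a
 * flattened input -- the kept values plus their int32 linear indices. The
 * "mask is non-zero" condition shown here is just one example; the actual
 * condition is selected via param(). `cond_take_ref` is a hypothetical
 * helper.
 */
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

static std::pair<std::vector<float>, std::vector<int32_t>> cond_take_ref(
        const std::vector<float>& data, const std::vector<float>& mask) {
    std::vector<float> values;
    std::vector<int32_t> indices;
    for (size_t i = 0; i < data.size(); ++i) {
        if (mask[i] != 0.f) {  // example condition only
            values.push_back(data[i]);
            indices.push_back(static_cast<int32_t>(i));  // flattened index
        }
    }
    return {values, indices};
}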

class TransposeForward: public OperatorBase {
    DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] src (m, n) stride[0] >= n && stride[1] == 1
     * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Transpose = TransposeForward;

/**
 * Change a tensor to another layout that has the same dtype and total number
 * of elements, and non-overlapping stride.
 *
 * ON CPU:
 * This operator is optimized for some cases (e.g. both dst and the last dim
 * of src are contiguous).
 *
 * ON CUDA:
 * The more contiguous the input/output layouts, the higher the performance.
 * There is also a special optimization for the broadcast case.
 */
class RelayoutForward: public OperatorBase {
    DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /*!
     * \brief execute relayout opr
     *
     * This operator should be placed on the same computing device as *dst*.
     *
     * \param src_handle handle of input tensor; for CUDA d2d copy, the
     *      src handle can be on a different GPU for copying tensors with
     *      non-contig dims <= 2
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      Handle *src_handle = nullptr) = 0;

protected:
    //! check layout and collapse contiguous
    void check_layout_and_canonize(
            TensorLayout &src, TensorLayout &dst);
};
using Relayout = RelayoutForward;

/**
 * \brief Base class for Concat and Split operators
 */
class ConcatSplitBase: public OperatorBase {
public:
    using Param = param::Axis;

    ConcatSplitBase(Handle *handle);
    const Param &param() const { return m_param; }
    Param &param() { return m_param; }

protected:
    void check_layout_common(const TensorLayoutArray &srcs,
            const TensorLayout &dst);
    Param m_param;

    /**
     * \brief a helper function
     *
     * A = shape[0] * shape[1] * ... * shape[axis-1]
     * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...}
     * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1]
     */
    void get_ABC(const TensorShapeArray &srcs,
            size_t &A,
            size_t *B,
            size_t &C);

    thin_function<TensorLayout(const TensorND &tensor)> m_get_layout;
    thin_function<TensorShape(const TensorLayout &layout)> m_get_shape;
};
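
/*
 * Illustrative sketch (not part of this header): the A/B/C decomposition
 * documented for get_ABC() above. Each src is viewed as an (A, B_i, C)
 * block, so concat and split reduce to strided copies. `get_ABC_ref` is a
 * hypothetical reference on plain shape vectors, assuming the shapes already
 * passed check_layout_common().
 */
#include <cstddef>
#include <vector>

static void get_ABC_ref(const std::vector<std::vector<size_t>>& srcs,
                        size_t axis, size_t& A, std::vector<size_t>& B,
                        size_t& C) {
    const std::vector<size_t>& shape0 = srcs[0];
    A = 1;
    for (size_t i = 0; i < axis; ++i)
        A *= shape0[i];            // product of dims before axis
    B.clear();
    for (const auto& s : srcs)
        B.push_back(s[axis]);      // per-src extent along axis
    C = 1;
    for (size_t i = axis + 1; i < shape0.size(); ++i)
        C *= shape0[i];            // product of dims after axis
}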

class ConcatForward: public ConcatSplitBase {
    DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] srcs a vector containing all inputs to be concatenated
     * \param[out] dst the output tensor.
     *
     * All tensors in srcs and dst should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for dst should be the
     * sum of corresponding axis shapes for all srcs.
     */
    virtual void exec(_megdnn_in const TensorNDArray &srcs,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayoutArray &srcs,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(
            const TensorLayoutArray &srcs,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayoutArray &srcs,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Concat = ConcatForward;

class SplitForward: public ConcatSplitBase {
    DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dsts a vector containing all split results
     *
     * All tensors in src and dsts should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for src should be the
     * sum of corresponding axis shapes for all dsts.
     */
    virtual void exec(_megdnn_tensor_in src,
            const TensorNDArray &dsts,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayoutArray &dsts) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayoutArray &dsts,
            size_t workspace_in_bytes);
};
using Split = SplitForward;

/**
 * \brief Base class for ParamPackConcat and ParamPackSplit operators.
 *
 * ParamPack oprs act like Concat and Split, but they are also optimized for
 * a large number of inputs and can handle alignment requirements. An axis
 * parameter is not supported.
 *
 * The offsets can be generated by gen_offsets(). The \p srcs in
 * ParamPackConcat and \p dsts in ParamPackSplit must be on CPU, and must
 * remain valid until the execution stream is synchronized.
 */
class ParamPackConcatSplitBase : public OperatorBase {
protected:
    void check_exec(const TensorLayout& concated, const TensorLayout& offsets,
                    const TensorLayout& parts);

public:
    using Param = megdnn::param::Empty;
    ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}

    //! generate offsets to be used with ParamPackConcat and ParamPackSplit
    static std::vector<dt_int32> gen_offsets(const TensorShapeArray& shapes,
                                             size_t alignment,
                                             size_t dtype_size);
};
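
/*
 * Illustrative sketch (not part of this header): one plausible layout scheme
 * for offsets with an alignment requirement, in the spirit of gen_offsets()
 * above -- round each tensor's start up to the alignment. This is a guess at
 * the scheme, not the verified MegDNN implementation; alignment is expressed
 * in elements here for simplicity, while the real gen_offsets() also takes a
 * dtype_size.
 */
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int32_t> gen_offsets_ref(
        const std::vector<size_t>& nr_elems,  // element count per tensor
        size_t alignment) {
    std::vector<int32_t> offsets;
    size_t cur = 0;
    for (size_t n : nr_elems) {
        cur = (cur + alignment - 1) / alignment * alignment;  // align start
        offsets.push_back(static_cast<int32_t>(cur));
        cur += n;
    }
    return offsets;
}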

/**
 * \brief ParamPackConcat, used for calculating the gradient of
 * ParamPackSplit.
 *
 * Combines multiple gradient tensors into a single large tensor, using a
 * copy strategy because of AddUpdate and other dynamic situations.
 */
class ParamPackConcat: public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1);

public:
    /*
     * \param[in] srcs TensorND on cpu; srcs[i] corresponds to the
     *      address of the i-th tensor.
     * \param[in] table with size `2 * srcs.nr_total_elems()`;
     *      table[addr] corresponds to outer_idx, and
     *      table[addr + srcs.nr_total_elems()] corresponds to the
     *      inner_idx of dsts.
     * \param[out] dst output TensorND, living on cpu or gpu
     */
    virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs,
                                          const TensorShape& table,
                                          const TensorShape& dst) = 0;
};

/**
 * \brief ParamPackSplit, used for network forwarding.
 *
 * Splits a single large param into several small tensors, also using a copy
 * strategy.
 */
class ParamPackSplit: public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackSplit, ParamPackConcatSplitBase, 2, 1);

public:
    /*
     * \param[in] src input TensorND, living on cpu or gpu
     * \param[in] table with size `2 * srcs.nr_total_elems()`;
     *      table[addr] corresponds to outer_idx, and
     *      table[addr + srcs.nr_total_elems()] corresponds to the
     *      inner_idx of dsts.
     * \param[out] dsts TensorND on cpu; dsts[i] corresponds to the address
     *      of the i-th tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
                      _megdnn_tensor_out dsts, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorShape& src,
                                          const TensorShape& table,
                                          const TensorShapeArray& dsts) = 0;
};

/**
 * \brief base class for Tile and Repeat
 */
class TileRepeatBase: public OperatorBase {
public:
    TileRepeatBase(Handle *handle): OperatorBase(handle) {}

    struct Param {
        TensorShape times;
    };
    Param &param() { return m_param; }
    const Param &param() const { return m_param; }

protected:
    void check_layout_fwd(const TensorLayout &src,
            const TensorLayout &dst);
    void deduce_layout_fwd(const TensorLayout &src,
            TensorLayout &dst);

    /**
     * Assuming src/dst/times are already simplified on entrance.
     */
    size_t get_workspace_in_bytes_fwd(const TensorShape &src,
            const TensorShape &dst,
            const TensorShape &times,
            DType dtype);

    Param m_param;
};

class TileBase: public TileRepeatBase {
public:
    TileBase(Handle *handle): TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape &src,
            const TensorShape &dst,
            const TensorShape &times,
            TensorShape &src2,
            TensorShape &dst2,
            TensorShape &times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout &src,
            const TensorLayout &dst);
};

class TileForward: public TileBase {
    DEF_OPR_IMPL(TileForward, TileBase, 1, 1);

public:
    /**
     * \brief Tile src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     *
     * Difference between Tile and Repeat:
     * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice
     * yields `aabbcc'.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Tile = TileForward;
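
/*
 * Illustrative sketch (not part of this header): the Tile-vs-Repeat
 * distinction from the doc above, on a 1-D string. Tiling "abc" twice gives
 * "abcabc"; repeating it twice gives "aabbcc". `tile_ref` and `repeat_ref`
 * are hypothetical helpers.
 */
#include <cstddef>
#include <string>

static std::string tile_ref(const std::string& s, size_t times) {
    std::string out;
    for (size_t t = 0; t < times; ++t)
        out += s;              // whole sequence repeated
    return out;
}

static std::string repeat_ref(const std::string& s, size_t times) {
    std::string out;
    for (char c : s)
        out.append(times, c);  // each element repeated in place
    return out;
}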

class TileBackward: public TileBase {
    DEF_OPR_IMPL(TileBackward, TileBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff,
            _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff, const TensorLayout &grad,
            size_t workspace_in_bytes);
};

class RepeatBase: public TileRepeatBase {
public:
    RepeatBase(Handle *handle): TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape &src,
            const TensorShape &dst,
            const TensorShape &times,
            TensorShape &src2,
            TensorShape &dst2,
            TensorShape &times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout &src,
            const TensorLayout &dst);
};

class RepeatForward: public RepeatBase {
    DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1);

public:
    /**
     * \brief Repeat src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
     * \see TileForward
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Repeat = RepeatForward;

class RepeatBackward: public RepeatBase {
    DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff,
            _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff,
            const TensorLayout &grad,
            size_t workspace_in_bytes);
};

class ArgsortForward: public OperatorBase {
    DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(Argsort);

public:
    using Order = Param::Order;

    /**
     * \param[in] src (m, n)
     * \param[out] dst (m, n)
     * \param[out] indices (m, n)
     *
     * src, dst and indices should be contiguous.
     * Performs m independent sorts on m arrays of length n, storing the
     * sorted arrays in `dst' and the corresponding indices in `indices'.
     *
     * Indices range from 0 to n-1.
     *
     * Note that indices is a TensorND of type int.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_out dst,
            _megdnn_tensor_out indices,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            TensorLayout &dst,
            TensorLayout &indices);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst,
            const TensorLayout &indices) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &dst,
            const TensorLayout &indices,
            size_t workspace_in_bytes);
};
using Argsort = ArgsortForward;
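
/*
 * Illustrative sketch (not part of this header): one row of ArgsortForward
 * in plain C++ -- sort the values and produce the matching int indices; the
 * real opr performs m such sorts independently. `argsort_row_ref` is a
 * hypothetical helper; the bool flag stands in for Param::Order.
 */
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

static std::pair<std::vector<float>, std::vector<int32_t>> argsort_row_ref(
        const std::vector<float>& src, bool ascending) {
    std::vector<int32_t> indices(src.size());
    std::iota(indices.begin(), indices.end(), 0);  // 0, 1, ..., n-1
    std::stable_sort(indices.begin(), indices.end(),
                     [&](int32_t a, int32_t b) {
                         return ascending ? src[a] < src[b] : src[a] > src[b];
                     });
    std::vector<float> dst(src.size());
    for (size_t i = 0; i < src.size(); ++i)
        dst[i] = src[indices[i]];  // gather values in sorted order
    return {dst, indices};
}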

/*!
 * \brief backward opr for Argsort
 *
 * Note: the name is kept for backward compatibility. This opr is actually a
 * batched value setter. It is used for gradient computing of Argsort and
 * TopK.
 */
class ArgsortBackward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] diff (m, k) the backpropagated gradient wrt. dst
     * \param[in] indices (m, k) the `indices' parameter in
     *      ArgsortForward::exec
     * \param[out] grad (m, n) the backpropagated gradient wrt. src
     *
     * Constraint: n >= k. Untouched values are initialized to zero.
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
                      _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& indices,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& indices,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
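
/*
 * Illustrative sketch (not part of this header): ArgsortBackward viewed as
 * the batched value setter described above -- grad[i][indices[i][j]] =
 * diff[i][j], with untouched entries zero-initialized. Reference code only;
 * `argsort_backward_ref` is a hypothetical helper.
 */
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static void argsort_backward_ref(
        const std::vector<std::vector<float>>& diff,       // (m, k)
        const std::vector<std::vector<int32_t>>& indices,  // (m, k)
        std::vector<std::vector<float>>& grad) {           // (m, n), n >= k
    for (auto& row : grad)
        std::fill(row.begin(), row.end(), 0.f);  // untouched values -> zero
    for (size_t i = 0; i < diff.size(); ++i)
        for (size_t j = 0; j < diff[i].size(); ++j)
            grad[i][indices[i][j]] = diff[i][j];
}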

class TopK : public OperatorBase {
    DEF_OPR_IMPL(TopK, OperatorBase, 1, 2);
    DEF_OPR_PARAM(TopK);

protected:
    //! impl exec; inputs have been validated
    virtual void do_exec(int k, _megdnn_tensor_in data,
                         _megdnn_tensor_out values, int32_t* indices,
                         _megdnn_workspace workspace) = 0;

public:
    /*!
     * \param[in] k if positive, compute the smallest top-k values; otherwise
     *      compute the largest top-k values
     * \param[in] data (m, n) input data, where top-k is computed on the
     *      second axis. The second dimension must be contiguous, and the
     *      first dimension can have arbitrary stride.
     * \param[out] values (m, ) or (m, k) output values; its shape depends
     *      on mode
     * \param[out] indices () or (m, ) or (m, k) output indices; its shape
     *      depends on mode
     */
    void exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
              _megdnn_tensor_out indices, _megdnn_workspace workspace);

    virtual size_t get_workspace_in_bytes(int k, const TensorLayout& data,
                                          const TensorLayout& values,
                                          const TensorLayout& indices) = 0;

    void deduce_layout(int k, const TensorLayout& data, TensorLayout& values,
                       TensorLayout& indices);
};
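
/*
 * Illustrative sketch (not part of this header): the sign-of-k convention
 * documented for TopK::exec -- k > 0 selects the k smallest values, k < 0
 * the k largest. A single-row reference using std::partial_sort; the output
 * ordering here is an assumption, and the real opr additionally
 * distinguishes output modes via param(). Assumes |k| <= row.size().
 */
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <vector>

static std::vector<float> topk_row_ref(std::vector<float> row, int k) {
    size_t kk = static_cast<size_t>(std::abs(k));
    if (k > 0)  // k smallest, sorted ascending
        std::partial_sort(row.begin(), row.begin() + kk, row.end());
    else        // k largest, sorted descending
        std::partial_sort(row.begin(), row.begin() + kk, row.end(),
                          std::greater<float>());
    row.resize(kk);
    return row;
}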

/*!
 * \brief convert dtype of *src* to match dtype of *dst*; *src* may have
 * arbitrary layout and *dst* must be contiguous.
 */
class TypeCvtForward: public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst);
};
using TypeCvt = TypeCvtForward;

class IndexingRemapBase: public OperatorBase {
public:
    using Param = param::IndexingRemap;

    IndexingRemapBase(Handle *handle): OperatorBase(handle) {}
    Param &param() { return m_param; }
    const Param &param() const { return m_param; }

protected:
    Param m_param;

    void check_layout_fwd(const TensorLayout &src,
            const TensorLayout &map,
            const TensorLayout &dst);
};

class IndexingRemapForward: public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] map input map
     * \param[out] dst output tensor
     *
     * Suppose:
     * the shape of src is \f$(s_0, s_1, ..., s_{m-1})\f$;
     * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$;
     * then:
     * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$.
     *
     * The last dimension of map indicates the src indices for the
     * corresponding dst entry.
     *
     * src and dst can be non-contiguous in a non-overlapping manner.
     */
    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_in map,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src,
            const TensorLayout &map,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &map,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &map,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using IndexingRemap = IndexingRemapForward;
// The using directives preserve backward compatibility.
using TensorRemapForward = IndexingRemap;
using TensorRemap = TensorRemapForward;
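
/*
 * Illustrative sketch (not part of this header): the map semantics of
 * IndexingRemapForward for 1-D src and 2-D dst. With dst shape (d0, d1) and
 * src ndim m = 1, map has shape (d0, d1, 1) and stores, per dst entry, the
 * src subscript to read. `remap_ref` is a hypothetical helper on flattened
 * arrays.
 */
#include <cstddef>
#include <vector>

static std::vector<float> remap_ref(const std::vector<float>& src,
                                    const std::vector<int>& map,  // (d0, d1, 1)
                                    size_t d0, size_t d1) {
    std::vector<float> dst(d0 * d1);
    for (size_t i = 0; i < d0 * d1; ++i)
        dst[i] = src[map[i]];  // last dim of map gives the src index
    return dst;
}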

class IndexingRemapBackward: public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] map the `map' parameter in IndexingRemapForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff,
            _megdnn_tensor_in map,
            _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &map,
            const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff,
            const TensorLayout &map,
            const TensorLayout &grad,
            size_t workspace_in_bytes);
};
// The using directive preserves backward compatibility.
using TensorRemapBackward = IndexingRemapBackward;

class Linspace: public OperatorBase {
    DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1);
    DEF_OPR_PARAM(LinspaceFull);

public:
    /**
     * \param[out] dst must be 1d.
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html
     */
    virtual void exec(_megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

class Eye: public OperatorBase {
    DEF_OPR_IMPL(Eye, OperatorBase, 0, 1);
    DEF_OPR_PARAM(Eye);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html
     */
    virtual void exec(_megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

class IndexingOneHotBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void deduce_layout_fwd(const TensorLayout &src,
            const TensorLayout &index,
            TensorLayout &dst);
    void check_layout_fwd(const TensorLayout &src,
            const TensorLayout &index,
            const TensorLayout &dst);
};

/*!
 * \brief Indexing for one-hot encoding
 *
 * Given src, axis and index,
 * for all valid (n-1)-dimensional subscript tuples i iterating through index:
 * dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] =
 *      inp[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]
 *
 * \param[in] src n-dimensional input data
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[out] dst n-dimensional output data
 */
class IndexingOneHotForward: public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1);

public:
    void deduce_layout(const TensorLayout &src,
            const TensorLayout &index, TensorLayout &dst) {
        deduce_layout_fwd(src, index, dst);
    }

    virtual void exec(_megdnn_tensor_in src,
            _megdnn_tensor_in index,
            _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &index,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src,
            const TensorLayout &index, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using IndexingOneHot = IndexingOneHotForward;
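
/*
 * Illustrative sketch (not part of this header): the IndexingOneHot formula
 * above instantiated for a 2-D src with axis == 1, where it reduces to
 * dst[i] = src[i][index[i]] -- the classic "pick one entry per row" gather.
 * `indexing_one_hot_ref` is a hypothetical helper.
 */
#include <cstddef>
#include <vector>

static std::vector<float> indexing_one_hot_ref(
        const std::vector<std::vector<float>>& src,  // (m, n)
        const std::vector<int>& index) {             // (m, ), 0 <= idx < n
    std::vector<float> dst(src.size());
    for (size_t i = 0; i < src.size(); ++i)
        dst[i] = src[i][index[i]];
    return dst;
}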

/*!
 * \brief set-subtensor corresponding to IndexingOneHotForward
 *
 * \param[in,out] data n-dimensional input and output data, whose sub part
 *      corresponding to *index* would be replaced by *sub*
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[in] sub n-dimensional sub tensor to be filled in *data*
 */
class IndexingSetOneHotForward: public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index,
                      _megdnn_tensor_in sub, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &data,
            const TensorLayout &index,
            const TensorLayout &sub) = 0;

protected:
    void check_exec(const TensorLayout &data,
            const TensorLayout &index, const TensorLayout &sub,
            size_t workspace_in_bytes);
};
using IndexingSetOneHot = IndexingSetOneHotForward;

/*!
 * \brief base class for indexing on multiple axes using vector indices
 *
 * Note that the indexing axes are required to be sorted in ascending order
 */
class IndexingMultiAxisVecBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase);
    DEF_OPR_PARAM(Empty);

public:
    struct AxisIndexer {
        size_t axis;
        TensorND vec;
    };

    struct AxisIndexerLayoutOnly {
        size_t axis;
        TensorLayout layout;
    };

    using IndexDesc = std::vector<AxisIndexer>;
    using IndexDescLayoutOnly = std::vector<AxisIndexerLayoutOnly>;

    /*!
     * \brief convert IndexDesc to IndexDescLayoutOnly
     */
    static IndexDescLayoutOnly extract_index_layout(const IndexDesc &index);

    /*!
     * \brief get the axes on src that are not used in index
     * \param[out] out output buffer; suggested size is
     *      TensorLayout::MAX_NDIM
     * \return number of elements written to *out*
     */
    static size_t get_nonindex_axes(size_t src_ndim, const IndexDesc &index,
                                    size_t *out);

    /*!
     * \brief get contiguous-collapsed layout for indexing on value
     * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis)
     * \return a tensor layout and an axis to iterate over *value* and also
     *      access *data*; stride of layout on that axis would be zero, and
     *      strides on other axes correspond to the strides in *data*
     */
    static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
            const TensorLayout &data, const TensorLayout &value,
            const IndexDesc &index, size_t idx_axis);

    //! helper info for kernel implementation
    struct ExecInfo {
        //! axis in value used by indexer
        size_t idx_axis;
        ptrdiff_t value_stride;
        void* error_tracker;
        megcore::AsyncErrorInfo* error_info;
    };

protected:
    /*!
     * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis)
     */
    static size_t deduce_layout_fwd(
            const TensorLayout &data,
            const IndexDescLayoutOnly &index,
            TensorLayout &dst);

    static ExecInfo check_exec_noworkspace(
            const TensorLayout &data, const TensorLayout &value,
            const IndexDesc &index, IndexDescLayoutOnly &index_layout);
};

/*!
 * \brief compute indexing result, like numpy advanced indexing
 *
 * src can have arbitrary layout, but dst must be dim1-contig
 */
class IndexingMultiAxisVec: public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1);

public:
    virtual void exec(_megdnn_tensor_in src,
            const IndexDesc &index,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on output shape and indexing axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape &dst,
            const size_t *axes, size_t nr_axes);

    static void deduce_layout(
            const TensorLayout &data,
            const IndexDescLayoutOnly &index,
            TensorLayout &dst) {
        deduce_layout_fwd(data, index, dst);
    }

protected:
    virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0;

    ExecInfo check_exec(
            const TensorLayout &src,
            const IndexDesc &index,
            const TensorLayout &dst,
            size_t workspace_in_bytes);
};

/*!
 * \brief base class for modifying data by given index
 *
 * data can have arbitrary layout, but value must be dim1-contig
 */
class IndexingModifyMultiAxisVecBase: public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase, IndexingMultiAxisVecBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc &index,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on shape of value input and indexing
     *      axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape &value,
            const size_t *axes, size_t nr_axes);

protected:
    ExecInfo check_exec(
            const TensorLayout &data, const TensorLayout &value,
            const IndexDesc &index,
            size_t workspace_in_bytes);

    virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0;
};

//! set value to indexed locations; index values must be non-overlapping
class IndexingSetMultiAxisVec: public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingSetMultiAxisVec,
            IndexingModifyMultiAxisVecBase, 0, 0);
};

//! add value to indexed locations; index values must be non-overlapping
class IndexingIncrMultiAxisVec: public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingIncrMultiAxisVec,
            IndexingModifyMultiAxisVecBase, 0, 0);
};

class MeshBase : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL_CTOR(MeshBase, OperatorBase);

public:
    using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer;
    using IndexDesc = IndexingMultiAxisVecBase::IndexDesc;
    using AxisIndexerLayoutOnly =
            IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
    using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;

    size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) {
        return 0;
    }

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc);
};

class NormalMeshBase : public MeshBase {
    DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0);

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc) override final;
};

class NormalMeshModifyBase : public NormalMeshBase {
    DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
                      const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class BatchedMeshBase : public MeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase);

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc) override final;
};

class BatchedMeshModifyBase : public BatchedMeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
                      const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class MeshIndexing : public NormalMeshBase {
    DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    static void deduce_layout(const TensorLayout& inp,
                              const IndexDescLayoutOnly& layouts,
                              TensorLayout& out_layout);
};

class IncrMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(IncrMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class SetMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(SetMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class BatchedMeshIndexing : public BatchedMeshBase {
    DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;

    static void deduce_layout(const TensorLayout& inp,
                              const IndexDescLayoutOnly& layouts,
                              TensorLayout& out_layout);
};

class BatchedIncrMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedIncrMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class BatchedSetMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedSetMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class RelayoutFormat : public OperatorBase {
    DEF_OPR_PARAM(RelayoutFormat);
    DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    void deduce_format(TensorFormat src, TensorFormat& dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
    void deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst,
                            TensorLayout& exec_src, TensorLayout& exec_dst);
};

} // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen
