/**
 * \file dnn/include/megdnn/oprs/general.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/internal/opr_header_prologue.h"
#include "megdnn/thin/small_vector.h"

namespace megdnn {

/*!
 * \brief standard element-wise operator
 *
 * Inputs must have the same dtype, and their shapes must be broadcastable
 * into a final shape. They can have arbitrary layouts, but non-contiguous and
 * non-broadcast layouts may harm performance seriously.
 *
 * Output dtype is the same as input dtype (note that even for compare oprs
 * this is true, e.g. float == float returns a value of float). Output layout
 * must be contiguous.
 */
class ElemwiseForward: public OperatorBase {
    DEF_OPR_PARAM(Elemwise);
    DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1);

public:
    using Mode = Param::Mode;

    //! information about a mode
    struct ModeTrait {
        uint32_t arity;     //!< number of inputs needed
        bool commutable;    //!< whether arity == 2 and inputs commutable
        bool allow_int;     //!< whether int inputs allowed
        bool allow_float;   //!< whether float inputs allowed
        bool allow_bool;    //!< whether bool inputs allowed
        const char* name;   //!< name of the mode

        ModeTrait():
            arity(0), commutable(0), allow_int(0), allow_float(0),
            allow_bool(0), name(NULL)
        {}

        //! get trait from a mode; this function is thread safe
        static const ModeTrait& from_mode(Mode mode);
    };

    //! get trait of current mode
    const ModeTrait& mode_trait() const {
        return ModeTrait::from_mode(m_param.mode);
    }

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should have the same shape;
     * layouts should be contiguous;
     * the underlying data pointer can point to the same memory region for
     * src and dst.
     */
    virtual void exec(_megdnn_in const TensorNDArray &src,
            _megdnn_tensor_out dst) = 0;

    //! deduce output shape (do not check whether arity matches)
    static void deduce_shape(const TensorShapeArray &src, TensorShape &dst);

    static void deduce_format(const TensorFormatArray& src, TensorFormat& dst);

    //! deduce output layout
    void deduce_layout(const TensorLayoutArray &src, TensorLayout &dst);

protected:
    //! throw exception if incorrect layout; broadcast input shape to
    //! output shape
    void check_layout_and_broadcast(
            const TensorLayoutPtrArray &src, const TensorLayout &dst);

private:
    void check_dtype(DType dtype);
};
using Elemwise = ElemwiseForward;
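/*!
 * Minimal usage sketch for Elemwise (illustrative only; assumes a valid
 * Handle* named `handle` and pre-allocated TensorNDs a, b and dst):
 * \code
 *     auto opr = handle->create_operator<Elemwise>();
 *     opr->param().mode = Elemwise::Mode::ADD;   // binary, commutable
 *     // broadcast shape deduction, e.g. (3, 1) + (1, 4) -> (3, 4)
 *     TensorShapeArray srcs{a.layout, b.layout};
 *     TensorShape dst_shape;
 *     Elemwise::deduce_shape(srcs, dst_shape);
 *     opr->exec({a, b}, dst);                    // dst must be contiguous
 * \endcode
 */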
/*!
 * \brief compute ``x**a`` where ``a`` is a constant from the Param
 *
 * This opr is usually not directly accessible by the end user; it is created
 * by the mgb optimizer, aiming to work around numerical stability issues with
 * pow. For example ``powf(x, 2.f)`` with ``x < 0`` in fast math mode may
 * return NaN.
 *
 * Like elemwise, this opr supports arbitrary strides. But it should only be
 * used with monotone strides. Input and output should have the same
 * float-category dtype.
 */
class PowC : public OperatorBase {
    DEF_OPR_PARAM(PowC);
    DEF_OPR_IMPL(PowC, OperatorBase, 1, 1);

public:
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst);

    //! compatible API for mgb; workspace is not used
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace) {
        return exec(src, dst);
    }

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
        // the impls should require no workspace; this can be later changed to
        // a virtual function if this situation changes
        return 0;
    }

    void deduce_layout(const TensorLayout& src, TensorLayout& dst) {
        dst.dtype = src.dtype;
        dst.init_contiguous_stride(src);
    }

protected:
    /*!
     * Perform the computation after layouts have been verified.
     *
     * \p src can have arbitrary layout, and \p dst is contiguous. They have
     * the same shape and dtype.
     *
     * The implementation should not access param(). It should check \p exp_f
     * and \p exp_i for the exponent value. Exactly one of them would be
     * non-null.
     *
     * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any
     * kernel. They are allocated on the caller's stack.
     */
    virtual void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            const float* exp_f, const int* exp_i) = 0;
};
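/*!
 * Minimal usage sketch for PowC (illustrative only; assumes a valid Handle*
 * and that the exponent param field is named `exp`):
 * \code
 *     auto opr = handle->create_operator<PowC>();
 *     opr->param().exp = 2.f;                      // compute x**2 elementwise
 *     TensorLayout dst_layout;
 *     opr->deduce_layout(src.layout, dst_layout);  // same dtype, contiguous
 *     opr->exec(src, dst);
 * \endcode
 */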
/*!
 * \brief modify a tensor inplace by adding another tensor to it
 *
 * dst and delta can have arbitrary layout but must have the same shape.
 */
class AddUpdateForward: public OperatorBase {
    DEF_OPR_PARAM(AddUpdate);
    DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0;

protected:
    void check_exec(const TensorLayout &dst, const TensorLayout &delta);
};
using AddUpdate = AddUpdateForward;
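/*!
 * Minimal usage sketch for AddUpdate (illustrative only; assumes a valid
 * Handle* and that the param carries alpha/beta/bias factors computing
 * dst = alpha * dst + beta * delta + bias):
 * \code
 *     auto opr = handle->create_operator<AddUpdate>();
 *     opr->exec(dst, delta);   // with the assumed defaults: dst += delta
 * \endcode
 */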
class ReduceForward: public OperatorBase {
    DEF_OPR_PARAM(Reduce);
    DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1);

public:
    using Mode = Param::Mode;
    using DataType = Param::DataType;

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Reduce = ReduceForward;
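/*!
 * Minimal usage sketch for Reduce (illustrative only; assumes a valid
 * Handle* and a raw workspace buffer ws_ptr):
 * \code
 *     auto opr = handle->create_operator<Reduce>();
 *     opr->param().mode = Reduce::Mode::SUM;
 *     opr->param().axis = 1;                       // reduce (m, n) -> (m, 1)
 *     TensorLayout dst_layout;
 *     opr->deduce_layout(src.layout, dst_layout);
 *     size_t ws_size = opr->get_workspace_in_bytes(src.layout, dst_layout);
 *     opr->exec(src, dst, {ws_ptr, ws_size});
 * \endcode
 */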
class CumsumForward: public OperatorBase {
    DEF_OPR_PARAM(Cumsum);
    DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should have the same shape.
     *
     * The exclusive flag specifies whether the current element is taken
     * into account when calculating results.
     *
     * The reverse flag specifies whether cumsum is forward (
     * from 0 to n) or backward (from n down to 0).
     *
     * Example:
     * exclusive && reverse:
     *      dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1}
     * exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_{i-1}
     * !exclusive && reverse:
     *      dst_i = src_i + src_{i+1} + ... + src_{n-1}
     * !exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_i
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Cumsum = CumsumForward;
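/*!
 * Concrete instance of the Cumsum flag combinations above (illustrative),
 * for src = [1, 2, 3]:
 *      exclusive  && reverse  -> dst = [5, 3, 0]
 *      exclusive  && !reverse -> dst = [0, 1, 3]
 *      !exclusive && reverse  -> dst = [6, 5, 3]
 *      !exclusive && !reverse -> dst = [1, 3, 6]
 */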
// mxx can be max or min
class ArgmxxBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst);
};

class ArgmaxForward: public ArgmxxBase {
    DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmax indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Argmax = ArgmaxForward;

class ArgminForward: public ArgmxxBase {
    DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmin indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Argmin = ArgminForward;

/*!
 * \brief take values from input according to given condition
 *
 * Output two tensors:
 * 1. values copied from *data*, with the same dtype as *data*
 * 2. selected indices with dtype int32; note that it is 1-dimensional and
 *    based on the flattened input.
 *
 * Require data and mask to have the same shape and both be contiguous.
 */
class CondTake : public OperatorBase {
    DEF_OPR_IMPL(CondTake, OperatorBase, 2, 2);
    DEF_OPR_PARAM(CondTake);

public:
    using Output = std::array<TensorND, 2>;
    using OutputDType = std::array<DType, 2>;

    OutputDType infer_dtype(DType data, DType mask);

    virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;

    virtual Output exec(_megdnn_tensor_in data, _megdnn_tensor_in mask,
            _megdnn_workspace workspace,
            DynOutMallocPolicyCall malloc_policy) = 0;

protected:
    //! check input layouts and get flattened size
    size_t check_exec_get_size(const TensorLayout& data,
            const TensorLayout& mask, size_t workspace_in_bytes);
};
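/*!
 * Minimal usage sketch for CondTake (illustrative only; assumes a valid
 * Handle*, a DynOutMallocPolicyCall for the dynamically-sized outputs, and a
 * raw workspace buffer ws_ptr):
 * \code
 *     auto opr = handle->create_operator<CondTake>();
 *     size_t ws_size = opr->get_workspace_in_bytes(data.layout);
 *     auto out = opr->exec(data, mask, {ws_ptr, ws_size}, malloc_policy);
 *     // out[0]: taken values (same dtype as data)
 *     // out[1]: int32 indices into the flattened data
 * \endcode
 */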
class TransposeForward: public OperatorBase {
    DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] src (m, n) stride[0] >= n && stride[1] == 1
     * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Transpose = TransposeForward;

/**
 * Change a tensor to another layout that has the same dtype, the same total
 * number of elements, and a non-overlapping stride.
 *
 * ON CPU:
 * This operator is optimized for some cases (e.g. both dst and the last dim
 * of src are contiguous).
 *
 * ON CUDA:
 * The more contiguous the input/output layouts, the higher the performance.
 * There is also a special optimization for the broadcast case.
 */
class RelayoutForward: public OperatorBase {
    DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /*!
     * \brief execute relayout opr
     *
     * This operator should be placed on the same computing device as *dst*.
     *
     * \param src_handle handle of the input tensor; for CUDA d2d copy, the
     *      src handle can be on a different GPU when copying a tensor with
     *      non-contig dims <= 2
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            Handle *src_handle = nullptr) = 0;

protected:
    //! check layout and collapse contiguous
    void check_layout_and_canonize(TensorLayout &src, TensorLayout &dst);
};
using Relayout = RelayoutForward;
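/*!
 * Minimal usage sketch for Relayout (illustrative only; assumes a valid
 * Handle*): a transpose can be expressed as a copy from a stride-swapped
 * view of src into a contiguous dst:
 * \code
 *     auto opr = handle->create_operator<Relayout>();
 *     TensorND src_view{src.raw_ptr, src.layout.dimshuffle({1, 0})};
 *     opr->exec(src_view, dst);  // dst holds the transposed, contiguous copy
 * \endcode
 */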
/**
 * \brief Base class for Concat and Split operators
 */
class ConcatSplitBase: public OperatorBase {
public:
    using Param = param::Axis;

    ConcatSplitBase(Handle *handle);

    const Param &param() const { return m_param; }
    Param &param() { return m_param; }

protected:
    void check_layout_common(const TensorLayoutArray &srcs,
            const TensorLayout &dst);

    Param m_param;

    /**
     * \brief a helper function
     *
     * A = shape[0] * shape[1] * ... * shape[axis-1]
     * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...}
     * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1]
     */
    void get_ABC(const TensorShapeArray &srcs,
            size_t &A, size_t *B, size_t &C);

    thin_function<TensorLayout(const TensorND &tensor)> m_get_layout;
    thin_function<TensorShape(const TensorLayout &layout)> m_get_shape;
};
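/*!
 * Worked instance of the A/B/C decomposition above (illustrative):
 * concatenating srcs of shapes (2, 3, 5) and (2, 4, 5) along axis = 1 gives
 * A = 2, B = {3, 4}, C = 5; each source then behaves as an (A, B[i], C)
 * block, and dst is (A, sum(B), C) = (2, 7, 5).
 */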
class ConcatForward: public ConcatSplitBase {
    DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] srcs a vector containing all inputs to be concatenated
     * \param[out] dst the output tensor.
     *
     * All tensors in srcs and dst should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for dst should be the
     * sum of corresponding axis shapes for all srcs.
     */
    virtual void exec(_megdnn_in const TensorNDArray &srcs,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayoutArray &srcs, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayoutArray &srcs,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayoutArray &srcs, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Concat = ConcatForward;

class SplitForward: public ConcatSplitBase {
    DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dsts a vector containing all split results
     *
     * All tensors in src and dsts should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for src should be the
     * sum of corresponding axis shapes for all dsts.
     */
    virtual void exec(_megdnn_tensor_in src, const TensorNDArray &dsts,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayoutArray &dsts) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayoutArray &dsts,
            size_t workspace_in_bytes);
};
using Split = SplitForward;

/**
 * \brief Base class for the ParamPackConcat and ParamPackSplit operators.
 *
 * ParamPack oprs act like Concat and Split, but they are also optimized for
 * a large number of inputs and can handle alignment requirements. The axis
 * param is not supported.
 *
 * The offsets can be generated by gen_offsets().
 */
class ParamPackConcatSplitBase : public OperatorBase {
protected:
    void check_exec(const TensorLayout& concated, const TensorLayout& offsets,
            const TensorLayout& parts);

public:
    using Param = megdnn::param::Empty;

    ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}

    //! generate offsets to be used with ParamPackConcat and ParamPackSplit
    static std::vector<dt_int32> gen_offsets(const TensorShapeArray& shapes,
            size_t alignment, size_t dtype_size);
};
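/*!
 * Sketch of gen_offsets usage (illustrative only): packing two float32
 * tensors of shapes (3,) and (5,) with 16-byte alignment:
 * \code
 *     auto offsets = ParamPackConcatSplitBase::gen_offsets(
 *             {TensorShape{3}, TensorShape{5}}, 16, sizeof(float));
 *     // offsets holds {begin_0, end_0, begin_1, end_1} in elements; each
 *     // begin is padded so that begin * dtype_size meets the alignment,
 *     // e.g. {0, 3, 4, 9} under these assumptions
 * \endcode
 */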
/**
 * \brief ParamPackConcat, used for calculating the gradient of
 * ParamPackSplit.
 *
 * Combine multiple gradient tensors into a single large tensor, using a copy
 * strategy to accommodate AddUpdate and other dynamic situations.
 */
class ParamPackConcat: public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1);

public:
    /**
     * \param[in] srcs TensorND on cpu; srcs[i] corresponds to the address
     *      of the i-th tensor.
     * \param[in] offsets with size `2 * srcs.shape[0]`;
     *      offsets[i * 2] and offsets[i * 2 + 1] are the begin and the end
     *      of srcs[i]'s offsets in dst
     * \param[out] dst output TensorND, residing on cpu or gpu
     */
    virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs,
            const TensorShape& offsets, const TensorShape& dst) = 0;
};

/**
 * \brief base class for Tile and Repeat
 */
class TileRepeatBase: public OperatorBase {
public:
    TileRepeatBase(Handle *handle): OperatorBase(handle) {}

    struct Param {
        TensorShape times;
    };

    Param &param() { return m_param; }
    const Param &param() const { return m_param; }

protected:
    void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst);
    void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst);

    /**
     * Assuming src/dst/times are already simplified on entrance.
     */
    size_t get_workspace_in_bytes_fwd(const TensorShape &src,
            const TensorShape &dst, const TensorShape &times, DType dtype);

    Param m_param;
};

class TileBase: public TileRepeatBase {
public:
    TileBase(Handle *handle): TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape &src, const TensorShape &dst,
            const TensorShape &times, TensorShape &src2, TensorShape &dst2,
            TensorShape &times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout &src,
            const TensorLayout &dst);
};

class TileForward: public TileBase {
    DEF_OPR_IMPL(TileForward, TileBase, 1, 1);

public:
    /**
     * \brief Tile src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     *
     * Difference between Tile and Repeat:
     * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice
     * yields `aabbcc'.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Tile = TileForward;
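/*!
 * Minimal usage sketch for Tile (illustrative only; assumes a valid Handle*
 * and a raw workspace buffer ws_ptr):
 * \code
 *     auto opr = handle->create_operator<Tile>();
 *     opr->param().times = TensorShape{2, 3};  // tile a (4, 5) src to (8, 15)
 *     TensorLayout dst_layout;
 *     opr->deduce_layout(src.layout, dst_layout);
 *     size_t ws_size = opr->get_workspace_in_bytes(src.layout, dst_layout);
 *     opr->exec(src, dst, {ws_ptr, ws_size});
 * \endcode
 */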
class TileBackward: public TileBase {
    DEF_OPR_IMPL(TileBackward, TileBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff, const TensorLayout &grad,
            size_t workspace_in_bytes);
};

class RepeatBase: public TileRepeatBase {
public:
    RepeatBase(Handle *handle): TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape &src, const TensorShape &dst,
            const TensorShape &times, TensorShape &src2, TensorShape &dst2,
            TensorShape &times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout &src,
            const TensorLayout &dst);
};

class RepeatForward: public RepeatBase {
    DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1);

public:
    /**
     * \brief Repeat src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
     * \see TileForward
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            size_t workspace_in_bytes);
};
using Repeat = RepeatForward;

class RepeatBackward: public RepeatBase {
    DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff, const TensorLayout &grad,
            size_t workspace_in_bytes);
};

class ArgsortForward: public OperatorBase {
    DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(Argsort);

public:
    using Order = Param::Order;

    /**
     * \param[in] src (m, n)
     * \param[out] dst (m, n)
     * \param[out] indices (m, n)
     *
     * src, dst and indices should be contiguous.
     * Performs m independent sorts on m arrays of length n, storing the
     * sorted arrays in `dst' and the corresponding indices in `indices'.
     *
     * Indices range from 0 to n-1.
     *
     * Note that indices is a TensorND of type int.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_tensor_out indices, _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, TensorLayout &dst,
            TensorLayout &indices);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &dst, const TensorLayout &indices) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst,
            const TensorLayout &indices, size_t workspace_in_bytes);
};
using Argsort = ArgsortForward;
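/*!
 * Minimal usage sketch for Argsort (illustrative only; assumes a valid
 * Handle* and a raw workspace buffer ws_ptr):
 * \code
 *     auto opr = handle->create_operator<Argsort>();
 *     opr->param().order = Argsort::Order::ASCENDING;
 *     TensorLayout dst_layout, idx_layout;
 *     opr->deduce_layout(src.layout, dst_layout, idx_layout);  // idx: int
 *     size_t ws_size = opr->get_workspace_in_bytes(
 *             src.layout, dst_layout, idx_layout);
 *     opr->exec(src, dst, indices, {ws_ptr, ws_size});
 * \endcode
 */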
/*!
 * \brief backward opr for Argsort
 *
 * Note: the name is kept for backward compatibility. This opr is actually a
 * batched value setter. It is used for gradient computing of Argsort and
 * TopK.
 */
class ArgsortBackward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] diff (m, k) the backpropagated gradient wrt. dst
     * \param[in] indices (m, k) the `indices' parameter in
     *      ArgsortForward::exec
     * \param[out] grad (m, n) the backpropagated gradient wrt. src
     *
     * Constraint: n >= k. Untouched values would be initialized as zero.
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
            const TensorLayout& indices, const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& indices,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class TopK : public OperatorBase {
    DEF_OPR_IMPL(TopK, OperatorBase, 1, 2);
    DEF_OPR_PARAM(TopK);

protected:
    //! impl exec; inputs have been validated
    virtual void do_exec(int k, _megdnn_tensor_in data,
            _megdnn_tensor_out values, int32_t* indices,
            _megdnn_workspace workspace) = 0;

public:
    /*!
     * \param[in] k if positive, compute the smallest top-k values; otherwise
     *      compute the largest top-k values
     * \param[in] data (m, n) input data, where top-k is computed on the
     *      second axis. The second dimension must be contiguous, and the
     *      first dimension can have arbitrary stride.
     * \param[out] values (m, ) or (m, k) output values; its shape depends
     *      on mode
     * \param[out] indices () or (m, ) or (m, k) output indices; its shape
     *      depends on mode
     */
    void exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
            _megdnn_tensor_out indices, _megdnn_workspace workspace);

    virtual size_t get_workspace_in_bytes(int k, const TensorLayout& data,
            const TensorLayout& values, const TensorLayout& indices) = 0;

    void deduce_layout(int k, const TensorLayout& data, TensorLayout& values,
            TensorLayout& indices);
};
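/*!
 * Minimal usage sketch for TopK (illustrative only; assumes a valid Handle*,
 * a raw workspace buffer ws_ptr, and that the param mode enum has a
 * VALUE_IDX_SORTED member producing (m, k) values plus matching indices):
 * \code
 *     auto opr = handle->create_operator<TopK>();
 *     opr->param().mode = TopK::Param::Mode::VALUE_IDX_SORTED;
 *     TensorLayout val_layout, idx_layout;
 *     opr->deduce_layout(3, data.layout, val_layout, idx_layout);  // k = 3
 *     size_t ws_size = opr->get_workspace_in_bytes(
 *             3, data.layout, val_layout, idx_layout);
 *     opr->exec(3, data, values, indices, {ws_ptr, ws_size});
 * \endcode
 */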
/*!
 * \brief convert dtype of *src* to match dtype of *dst*; *src* may have
 * arbitrary layout and *dst* must be contiguous.
 */
class TypeCvtForward: public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &dst);
};
using TypeCvt = TypeCvtForward;

class IndexingRemapBase: public OperatorBase {
public:
    using Param = param::IndexingRemap;

    IndexingRemapBase(Handle *handle): OperatorBase(handle) {}

    Param &param() { return m_param; }
    const Param &param() const { return m_param; }

protected:
    Param m_param;

    void check_layout_fwd(const TensorLayout &src, const TensorLayout &map,
            const TensorLayout &dst);
};

class IndexingRemapForward: public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] map input map
     * \param[out] dst output tensor
     *
     * Suppose:
     * the shape of src is \f$(s_0, s_1, ..., s_{m-1})\f$;
     * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$;
     * then:
     * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$.
     *
     * The last dimension of map indicates the src indices for the
     * corresponding dst entry.
     *
     * src and dst can be non-contiguous in a non-overlapping manner.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout &src, const TensorLayout &map,
            TensorLayout &dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &map, const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &map,
            const TensorLayout &dst, size_t workspace_in_bytes);
};
using IndexingRemap = IndexingRemapForward;
// The using directives preserve backward compatibility.
using TensorRemapForward = IndexingRemap;
using TensorRemap = TensorRemapForward;
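/*!
 * Shape example for IndexingRemap (illustrative): with src of shape (4, 5)
 * (so m = 2) and dst of shape (2, 3) (so n = 2), map must have shape
 * (2, 3, 2), and map[i][j] = {a, b} means dst[i][j] = src[a][b].
 */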
class IndexingRemapBackward: public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] map the `map' parameter in IndexingRemapForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in map,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &diff,
            const TensorLayout &map, const TensorLayout &grad) = 0;

protected:
    void check_exec(const TensorLayout &diff, const TensorLayout &map,
            const TensorLayout &grad, size_t workspace_in_bytes);
};
// The using directive preserves backward compatibility.
using TensorRemapBackward = IndexingRemapBackward;

class Linspace: public OperatorBase {
    DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1);
    DEF_OPR_PARAM(LinspaceFull);

public:
    /**
     * \param[out] dst must be 1d.
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html
     */
    virtual void exec(_megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

class Eye: public OperatorBase {
    DEF_OPR_IMPL(Eye, OperatorBase, 0, 1);
    DEF_OPR_PARAM(Eye);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html
     */
    virtual void exec(_megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};

class IndexingOneHotBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void deduce_layout_fwd(const TensorLayout &src, const TensorLayout &index,
            TensorLayout &dst);
    void check_layout_fwd(const TensorLayout &src, const TensorLayout &index,
            const TensorLayout &dst);
};

/*!
 * \brief Indexing for one-hot encoding
 *
 * Given src, axis and index,
 * for all valid (n-1)-dimensional subscript tuples i iterating through index:
 * dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] =
 *      src[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]
 *
 * \param[in] src n-dimensional input data
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[out] dst n-dimensional output data
 */
class IndexingOneHotForward: public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1);

public:
    void deduce_layout(const TensorLayout &src, const TensorLayout &index,
            TensorLayout &dst) {
        deduce_layout_fwd(src, index, dst);
    }

    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in index,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &src,
            const TensorLayout &index, const TensorLayout &dst) = 0;

protected:
    void check_exec(const TensorLayout &src, const TensorLayout &index,
            const TensorLayout &dst, size_t workspace_in_bytes);
};
using IndexingOneHot = IndexingOneHotForward;
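/*!
 * Concrete instance of the formula above (illustrative): for src of shape
 * (2, 3) and axis = 1, index has shape (2, ) and dst has shape (2, 1), with
 * dst[i, 0] = src[i, index[i]], i.e. one element is picked from each row.
 */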
/*!
 * \brief set-subtensor corresponding to IndexingOneHotForward
 *
 * \param[in,out] data n-dimensional input and output data, whose sub part
 *      corresponding to *index* would be replaced by *sub*
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[in] sub n-dimensional sub tensor to be filled in *data*
 */
class IndexingSetOneHotForward: public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index,
            _megdnn_tensor_in sub, _megdnn_workspace workspace) = 0;

    virtual size_t get_workspace_in_bytes(const TensorLayout &data,
            const TensorLayout &index, const TensorLayout &sub) = 0;

protected:
    void check_exec(const TensorLayout &data, const TensorLayout &index,
            const TensorLayout &sub, size_t workspace_in_bytes);
};
using IndexingSetOneHot = IndexingSetOneHotForward;

/*!
 * \brief base class for indexing on multiple axes using vector indices
 *
 * Note that the indexing axes are required to be sorted in ascending order
 */
class IndexingMultiAxisVecBase: public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase);
    DEF_OPR_PARAM(Empty);

public:
    struct AxisIndexer {
        size_t axis;
        TensorND vec;
    };

    struct AxisIndexerLayoutOnly {
        size_t axis;
        TensorLayout layout;
    };

    using IndexDesc = std::vector<AxisIndexer>;
    using IndexDescLayoutOnly = std::vector<AxisIndexerLayoutOnly>;

    /*!
     * \brief convert IndexDesc to IndexDescLayoutOnly
     */
    static IndexDescLayoutOnly extract_index_layout(const IndexDesc &index);

    /*!
     * \brief get the axes on src that are not used in index
     * \param[out] out output buffer; suggested size is
     *      TensorLayout::MAX_NDIM
     * \return number of elements written to *out*
     */
    static size_t get_nonindex_axes(size_t src_ndim, const IndexDesc &index,
            size_t *out);

    /*!
     * \brief get contiguous-collapsed layout for indexing on value
     * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis)
     * \return a tensor layout and an axis to iterate over *value* and also
     *      access *data*; stride of layout on that axis would be zero, and
     *      strides on other axes correspond to the strides in *data*
     */
    static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
            const TensorLayout &data, const TensorLayout &value,
            const IndexDesc &index, size_t idx_axis);

    //! helper info for kernel implementation
    struct ExecInfo {
        //! axis in value used by indexer
        size_t idx_axis;
        ptrdiff_t value_stride;
        void* error_tracker;
        megcore::AsyncErrorInfo* error_info;
    };

protected:
    /*!
     * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis)
     */
    static size_t deduce_layout_fwd(const TensorLayout &data,
            const IndexDescLayoutOnly &index, TensorLayout &dst);

    static ExecInfo check_exec_noworkspace(const TensorLayout &data,
            const TensorLayout &value, const IndexDesc &index,
            IndexDescLayoutOnly &index_layout);
};

/*!
 * \brief compute indexing result, like numpy advanced indexing
 *
 * src can have arbitrary layout, but dst must be dim1-contig
 */
class IndexingMultiAxisVec: public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc &index,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on output shape and indexing axes
     */
    size_t get_workspace_in_bytes(const TensorShape &dst,
            const size_t *axes, size_t nr_axes);

    static void deduce_layout(const TensorLayout &data,
            const IndexDescLayoutOnly &index, TensorLayout &dst) {
        deduce_layout_fwd(data, index, dst);
    }

protected:
    virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0;

    ExecInfo check_exec(const TensorLayout &src, const IndexDesc &index,
            const TensorLayout &dst, size_t workspace_in_bytes);
};
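/*!
 * Minimal usage sketch for IndexingMultiAxisVec (illustrative only; assumes
 * a valid Handle*, an int32 index tensor idx, and a raw workspace buffer
 * ws_ptr). It mimics numpy advanced indexing `dst = src[:, idx]` on a 2-dim
 * src:
 * \code
 *     auto opr = handle->create_operator<IndexingMultiAxisVec>();
 *     IndexingMultiAxisVec::IndexDesc desc{{1, idx}};  // index on axis 1
 *     TensorLayout dst_layout;
 *     IndexingMultiAxisVec::deduce_layout(
 *             src.layout, IndexingMultiAxisVec::extract_index_layout(desc),
 *             dst_layout);
 *     size_t axes[] = {1};
 *     size_t ws_size = opr->get_workspace_in_bytes(dst_layout, axes, 1);
 *     opr->exec(src, desc, dst, {ws_ptr, ws_size});
 * \endcode
 */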
/*!
 * \brief base class for modifying data by given index
 *
 * data can have arbitrary layout, but value must be dim1-contig
 */
class IndexingModifyMultiAxisVecBase: public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase, IndexingMultiAxisVecBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc &index, _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on shape of value input and indexing
     *      axes
     */
    size_t get_workspace_in_bytes(const TensorShape &value,
            const size_t *axes, size_t nr_axes);

protected:
    ExecInfo check_exec(const TensorLayout &data, const TensorLayout &value,
            const IndexDesc &index, size_t workspace_in_bytes);

    virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0;
};

//! set value to indexed locations; index values must be non-overlapping
class IndexingSetMultiAxisVec: public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingSetMultiAxisVec, IndexingModifyMultiAxisVecBase,
            0, 0);
};

//! add value to indexed locations; index values must be non-overlapping
class IndexingIncrMultiAxisVec: public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingIncrMultiAxisVec, IndexingModifyMultiAxisVecBase,
            0, 0);
};

class MeshBase : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL_CTOR(MeshBase, OperatorBase);

public:
    using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer;
    using IndexDesc = IndexingMultiAxisVecBase::IndexDesc;
    using AxisIndexerLayoutOnly =
            IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
    using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;

    size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) {
        return 0;
    }

protected:
    virtual void check_exec(const TensorLayout& origin,
            const TensorLayout& indexed, const IndexDesc& desc);
};

class NormalMeshBase : public MeshBase {
    DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0);

protected:
    virtual void check_exec(const TensorLayout& origin,
            const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class NormalMeshModifyBase : public NormalMeshBase {
    DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class BatchedMeshBase : public MeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase);

protected:
    virtual void check_exec(const TensorLayout& origin,
            const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class BatchedMeshModifyBase : public BatchedMeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class MeshIndexing : public NormalMeshBase {
    DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    static void deduce_layout(const TensorLayout& inp,
            const IndexDescLayoutOnly& layouts, TensorLayout& out_layout);
};

class IncrMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(IncrMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class SetMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(SetMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class BatchedMeshIndexing : public BatchedMeshBase {
    DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

    static void deduce_layout(const TensorLayout& inp,
            const IndexDescLayoutOnly& layouts, TensorLayout& out_layout);
};

class BatchedIncrMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedIncrMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class BatchedSetMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedSetMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class RelayoutFormat : public OperatorBase {
    DEF_OPR_PARAM(RelayoutFormat);
    DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    void deduce_format(TensorFormat src, TensorFormat& dst);

    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
            const TensorLayout& dst) = 0;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
    void deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst,
            TensorLayout& exec_src, TensorLayout& exec_dst);
};

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen
