

/**
 * \file dnn/include/megdnn/oprs/general.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"
#include "megdnn/thin/small_vector.h"

namespace megdnn {

/*!
 * \brief standard element-wise operator
 *
 * Inputs must have the same dtype, and their shapes must be broadcastable
 * into a final shape. They can have arbitrary layouts, but non-contiguous and
 * non-broadcast layouts may harm performance seriously.
 *
 * Output dtype is the same as input dtype (note that even for compare oprs
 * this is true, e.g. float == float returns a value of float). Output layout
 * must be contiguous.
 */
class ElemwiseForward : public OperatorBase {
    DEF_OPR_PARAM(Elemwise);
    DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1);

public:
    using Mode = Param::Mode;

    //! information about a mode
    struct ModeTrait {
        uint32_t arity;    //!< number of inputs needed
        bool commutable;   //!< whether arity == 2 and inputs commutable
        bool allow_int;    //!< whether int inputs allowed
        bool allow_float;  //!< whether float inputs allowed
        const char* name;  //!< name of the mode

        ModeTrait()
                : arity(0), commutable(0), allow_int(0), allow_float(0), name(NULL) {}

        //! get trait from a mode; this function is thread safe
        static const ModeTrait& from_mode(Mode mode);
    };

    //! get trait of current mode
    const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); }

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should have the same shape;
     * layouts should be contiguous;
     * the underlying data pointer can point to the same memory region for
     * src and dst.
     */
    virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;

    //! deduce output shape (do not check whether arity matches)
    static void deduce_shape(const TensorShapeArray& src, TensorShape& dst);

    static void deduce_format(const TensorFormatArray& src, TensorFormat& dst);

    //! deduce output layout
    void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);

protected:
    //! throw exception if incorrect layout; broadcast input shape to
    //! output shape
    void check_layout_and_broadcast(
            const TensorLayoutPtrArray& src, const TensorLayout& dst);

private:
    void check_dtype(DType dtype);
};
using Elemwise = ElemwiseForward;
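
// Usage sketch (illustrative, not part of the original header): running an
// elementwise ADD with broadcast. `handle` is assumed to be a valid megdnn
// Handle*, and `a` (2, 3), `b` (1, 3), `dst` (2, 3) pre-allocated float32
// TensorNDs:
//
//     auto opr = handle->create_operator<Elemwise>();
//     opr->param().mode = Elemwise::Mode::ADD;
//     TensorLayout dst_layout;
//     opr->deduce_layout({a.layout, b.layout}, dst_layout);  // (2, 3)
//     opr->exec({a, b}, dst);  // dst = a + b, b broadcast along axis 0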

/*!
 * \brief compute ``x**a`` where ``a`` is a constant from the Param
 *
 * This opr is usually not directly accessible by the end user and it is
 * created by mgb optimizer, aiming to work around numerical stability issues
 * with pow. For example ``powf(x, 2.f)`` with ``x < 0`` in fast math mode may
 * return NaN.
 *
 * Like elemwise, this opr supports arbitrary strides. But it should only be
 * used with monotone strides. Input and output should have the same
 * float-category dtype.
 */
class PowC : public OperatorBase {
    DEF_OPR_PARAM(PowC);
    DEF_OPR_IMPL(PowC, OperatorBase, 1, 1);

public:
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst);

    //! compatible API for mgb; workspace is not used
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) {
        return exec(src, dst);
    }

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
        // the impls should require no workspace; this can be later changed to
        // a virtual function if this situation changes
        return 0;
    }

    void deduce_layout(const TensorLayout& src, TensorLayout& dst) {
        dst.dtype = src.dtype;
        dst.init_contiguous_stride(src);
    }

protected:
    /*!
     * Perform the computing where layouts have been verified.
     *
     * \p src can have arbitrary layout, and \p dst is contiguous. They have
     * the same shape and dtype.
     *
     * The implementation should not access param(). It should check \p exp_f
     * and \p exp_i for the exponent value. Exactly one of them would be
     * non-null.
     *
     * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any
     * kernel. They are allocated on the caller's stack.
     */
    virtual void do_exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f,
            const int* exp_i) = 0;
};

/*!
 * \brief modify a tensor inplace by adding another tensor to it
 *
 * dst and delta can have arbitrary layout but must have the same shape.
 */
class AddUpdateForward : public OperatorBase {
    DEF_OPR_PARAM(AddUpdate);
    DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0;

protected:
    void check_exec(const TensorLayout& dst, const TensorLayout& delta);
};
using AddUpdate = AddUpdateForward;
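
// Usage sketch (illustrative; assumes the AddUpdate param provides alpha,
// beta and bias fields with semantics dst = alpha * dst + beta * delta +
// bias). `handle`, `dst` and `delta` are assumed valid:
//
//     auto opr = handle->create_operator<AddUpdate>();
//     opr->param().alpha = 1.f;
//     opr->param().beta = -0.01f;  // e.g. a plain SGD update step
//     opr->exec(dst, delta);       // in place: dst -= 0.01 * delta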

class ReduceForward : public OperatorBase {
    DEF_OPR_PARAM(Reduce);
    DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1);

public:
    using Mode = Param::Mode;
    using DataType = Param::DataType;

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * The param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Reduce = ReduceForward;
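
// Usage sketch (illustrative): summing a (2, 3) float32 tensor over axis 1
// into a (2, 1) output. `handle`, `src`, `dst` and a `workspace` of the
// queried size are assumed:
//
//     auto opr = handle->create_operator<Reduce>();
//     opr->param().mode = Reduce::Mode::SUM;
//     opr->param().axis = 1;
//     TensorLayout dst_layout;
//     opr->deduce_layout(src.layout, dst_layout);  // (2, 1)
//     size_t ws = opr->get_workspace_in_bytes(src.layout, dst_layout);
//     opr->exec(src, dst, workspace);              // workspace holds ws bytes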

class CumsumForward : public OperatorBase {
    DEF_OPR_PARAM(Cumsum);
    DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should have the same shape.
     *
     * The exclusive flag specifies whether the current element is taken
     * into account when calculating results.
     *
     * The reverse flag specifies whether cumsum is forward (
     * from 0 to n) or backward (from n down to 0).
     *
     * Example:
     * exclusive && reverse:
     *      dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1}
     * exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_{i-1}
     * !exclusive && reverse:
     *      dst_i = src_i + src_{i+1} + ... + src_{n-1}
     * !exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_i
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Cumsum = CumsumForward;
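
// Worked example for the four flag combinations, with src = [1, 2, 3, 4]:
//
//     exclusive=0, reverse=0:  dst = [1, 3, 6, 10]   // plain prefix sum
//     exclusive=1, reverse=0:  dst = [0, 1, 3, 6]    // shifted by one
//     exclusive=0, reverse=1:  dst = [10, 9, 7, 4]   // suffix sum
//     exclusive=1, reverse=1:  dst = [9, 7, 4, 0]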

// mxx can be max or min
class ArgmxxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class ArgmaxForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmax indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * The param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Argmax = ArgmaxForward;

class ArgminForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmin indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * The param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Argmin = ArgminForward;

/*!
 * \brief take values from input according to given condition
 *
 * Output two tensors:
 * 1. values copied from *data*, with the same dtype as *data*
 * 2. selected indices with dtype int32; note that it is 1-dimensional and
 *    based on the flattened input.
 *
 * Require data and mask to have the same shape and both be contiguous.
 */
class CondTake : public OperatorBase {
    DEF_OPR_IMPL(CondTake, OperatorBase, 2, 2);
    DEF_OPR_PARAM(CondTake);

public:
    using Output = std::array<TensorND, 2>;
    using OutputDType = std::array<DType, 2>;

    OutputDType infer_dtype(DType data, DType mask);
    virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
    virtual Output exec(
            _megdnn_tensor_in data, _megdnn_tensor_in mask,
            _megdnn_workspace workspace, DynOutMallocPolicyCall malloc_policy) = 0;

protected:
    //! check input layouts and get flattened size
    size_t check_exec_get_size(
            const TensorLayout& data, const TensorLayout& mask,
            size_t workspace_in_bytes);
};
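
// Worked example (illustrative; assumes the CondTake param provides a compare
// mode and a threshold value, e.g. mode=EQ, val=1): with
//
//     data = [10, 11, 12, 13]
//     mask = [ 0,  1,  1,  0]
//
// taking elements where mask == 1 yields values = [11, 12] and the flattened
// int32 indices = [1, 2].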

class TransposeForward : public OperatorBase {
    DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] src (m, n) stride[0] >= n && stride[1] == 1
     * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Transpose = TransposeForward;

/**
 * Change a tensor to another layout that has the same dtype and total number
 * of elements, and non-overlapping stride.
 *
 * ON CPU:
 * This operator is optimized for some cases (e.g. both dst and the last dim
 * of src are contiguous).
 *
 * ON CUDA:
 * The more contiguous the input/output layouts, the higher the performance.
 * There is also a special optimization for the broadcast case.
 */
class RelayoutForward : public OperatorBase {
    DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /*!
     * \brief execute relayout opr
     *
     * This operator should be placed on the same computing device as *dst*.
     *
     * \param src_handle handle of the input tensor; for CUDA d2d copy, the
     *      src handle can be on a different GPU, for copying tensors with
     *      no more than 2 non-contiguous dims
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            Handle* src_handle = nullptr) = 0;

protected:
    //! check layout and collapse contiguous
    void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst);
};
using Relayout = RelayoutForward;
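
// Usage sketch (illustrative): making a transposed view contiguous. `src` is
// assumed to be a (4, 3) view of a (3, 4) tensor obtained by swapping shape
// and stride entries, and `dst` a contiguous (4, 3) tensor:
//
//     auto opr = handle->create_operator<Relayout>();
//     opr->exec(src, dst);  // copies the strided view into contiguous memory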

/**
 * \brief Base class for Concat and Split operators
 */
class ConcatSplitBase : public OperatorBase {
public:
    using Param = param::Axis;

    ConcatSplitBase(Handle* handle);
    const Param& param() const { return m_param; }
    Param& param() { return m_param; }

protected:
    void check_layout_common(const TensorLayoutArray& srcs, const TensorLayout& dst);
    Param m_param;

    /**
     * \brief a helper function
     *
     * A = shape[0] * shape[1] * ... * shape[axis-1]
     * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...}
     * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1]
     */
    void get_ABC(const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C);

    thin_function<TensorLayout(const TensorND& tensor)> m_get_layout;
    thin_function<TensorShape(const TensorLayout& layout)> m_get_shape;
};

class ConcatForward : public ConcatSplitBase {
    DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] srcs a vector containing all inputs to be concatenated
     * \param[out] dst the output tensor
     *
     * All tensors in srcs and dst should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for dst should be the
     * sum of corresponding axis shapes for all srcs.
     */
    virtual void exec(
            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayoutArray& srcs, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayoutArray& srcs, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Concat = ConcatForward;
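
// Worked example (illustrative): concatenating srcs of shapes (2, 3, 4) and
// (2, 5, 4) with param().axis = 1 gives dst of shape (2, 8, 4). In terms of
// the get_ABC() helper above, A = 2, B = {3, 5} and C = 4.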

class SplitForward : public ConcatSplitBase {
    DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dsts a vector containing all split results
     *
     * All tensors in src and dsts should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for src should be the
     * sum of corresponding axis shapes for all dsts.
     */
    virtual void exec(
            _megdnn_tensor_in src, const TensorNDArray& dsts,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayoutArray& dsts) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayoutArray& dsts,
            size_t workspace_in_bytes);
};
using Split = SplitForward;

/**
 * \brief Base class for ParamPackConcat and ParamPackSplit Operators.
 *
 * ParamPack oprs act like Concat and Split, but they are also optimized for a
 * large number of inputs and can handle alignment requirements. Unlike Concat
 * and Split, an axis param is not supported.
 *
 * The offsets can be generated by gen_offsets().
 */
class ParamPackConcatSplitBase : public OperatorBase {
protected:
    void check_exec(
            const TensorLayout& concated, const TensorLayout& offsets,
            const TensorLayout& parts);

public:
    using Param = megdnn::param::Empty;

    ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}

    //! generate offsets to be used with ParamPackConcat and ParamPackSplit
    static std::vector<dt_int32> gen_offsets(
            const TensorShapeArray& shapes, size_t alignment, size_t dtype_size);
};
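
// Worked example (illustrative; assumes `alignment` is given in bytes and
// each tensor's begin offset is rounded up to it): packing two float32
// tensors of shapes {(3,), (5,)} with alignment = 16 and dtype_size = 4 would
// yield offsets = {0, 3, 4, 9}, i.e. tensor 0 occupies elements [0, 3) and
// tensor 1 occupies elements [4, 9) of the packed buffer.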

/**
 * \brief ParamPackConcat, used for calculating the gradient of ParamPackSplit.
 *
 * Combine multiple gradient tensors into a single large tensor; a copy
 * strategy is used to cope with AddUpdate and other dynamic situations.
 */
class ParamPackConcat : public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1);

public:
    /*!
     * \param[in] srcs TensorND on cpu; srcs[i] holds the address of the
     *      i-th tensor
     * \param[in] offsets with size `2 * srcs.shape[0]`;
     *      offsets[i * 2] and offsets[i * 2 + 1] are the begin and the end
     *      of srcs[i]'s offsets in dst
     * \param[out] dst output TensorND, living on cpu or gpu
     */
    virtual void exec(
            _megdnn_tensor_in srcs, _megdnn_tensor_in offsets,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorShapeArray& srcs, const TensorShape& offsets,
            const TensorShape& dst) = 0;
};

/**
 * \brief base class for Tile and Repeat
 */
class TileRepeatBase : public OperatorBase {
public:
    TileRepeatBase(Handle* handle) : OperatorBase(handle) {}

    struct Param {
        TensorShape times;
    };
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);

    /**
     * Assuming src/dst/times are already simplified on entrance.
     */
    size_t get_workspace_in_bytes_fwd(
            const TensorShape& src, const TensorShape& dst,
            const TensorShape& times, DType dtype);

    Param m_param;
};

class TileBase : public TileRepeatBase {
public:
    TileBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(
            const TensorShape& src, const TensorShape& dst,
            const TensorShape& times, TensorShape& src2, TensorShape& dst2,
            TensorShape& times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(
            const TensorLayout& src, const TensorLayout& dst);
};

class TileForward : public TileBase {
    DEF_OPR_IMPL(TileForward, TileBase, 1, 1);

public:
    /**
     * \brief Tile src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     *
     * Difference between Tile and Repeat:
     * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice
     * yields `aabbcc'.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Tile = TileForward;

class TileBackward : public TileBase {
    DEF_OPR_IMPL(TileBackward, TileBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class RepeatBase : public TileRepeatBase {
public:
    RepeatBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(
            const TensorShape& src, const TensorShape& dst,
            const TensorShape& times, TensorShape& src2, TensorShape& dst2,
            TensorShape& times2);

    /**
     * This is a helper function that would facilitate other backends'
     * implementation.
     */
    size_t get_workspace_in_bytes_fwd(
            const TensorLayout& src, const TensorLayout& dst);
};

class RepeatForward : public RepeatBase {
    DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1);

public:
    /**
     * \brief Repeat src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
     * \see TileForward
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
using Repeat = RepeatForward;

class RepeatBackward : public RepeatBase {
    DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
};

class ArgsortForward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(Argsort);

public:
    using Order = Param::Order;

    /**
     * \param[in] src (m, n)
     * \param[out] dst (m, n)
     * \param[out] indices (m, n)
     *
     * src, dst and indices should be contiguous.
     * Performs m independent sorts on m arrays of length n, storing the
     * sorted arrays in `dst' and the corresponding indices in `indices'.
     *
     * Indices range from 0 to n-1.
     *
     * Note that indices is a TensorND of type int.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_tensor_out indices, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& indices) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& indices, size_t workspace_in_bytes);
};
using Argsort = ArgsortForward;
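
// Worked example (illustrative): with m = 1, n = 3, ascending order and
// src = [[3, 1, 2]], exec() produces dst = [[1, 2, 3]] and
// indices = [[1, 2, 0]], since src[0][1] = 1 is the smallest element.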

/*!
 * \brief backward opr for Argsort
 *
 * Note: the name is kept for backward compatibility. This opr is actually a
 * batched value setter. It is used for gradient computing of Argsort and TopK.
 */
class ArgsortBackward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] diff (m, k) the backpropagated gradient wrt. dst
     * \param[in] indices (m, k) the `indices' parameter in
     *      ArgsortForward::exec
     * \param[out] grad (m, n) the backpropagated gradient wrt. src
     *
     * Constraint: n >= k. Untouched values would be initialized as zero.
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in indices,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& indices,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& indices,
            const TensorLayout& grad, size_t workspace_in_bytes);
};

class TopK : public OperatorBase {
    DEF_OPR_IMPL(TopK, OperatorBase, 1, 2);
    DEF_OPR_PARAM(TopK);

protected:
    //! impl exec; inputs have been validated
    virtual void do_exec(
            int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
            int32_t* indices, _megdnn_workspace workspace) = 0;

public:
    /*!
     * \param[in] k if positive, compute the smallest top-k values; otherwise
     *      compute the largest top-k values
     * \param[in] data (m, n) input data, where top-k is computed on the
     *      second axis. The second dimension must be contiguous, and the
     *      first dimension can have arbitrary stride.
     * \param[out] values (m, ) or (m, k) output values; its shape depends
     *      on mode
     * \param[out] indices () or (m, ) or (m, k) output indices; its shape
     *      depends on mode
     */
    void exec(
            int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
            _megdnn_tensor_out indices, _megdnn_workspace workspace);
    virtual size_t get_workspace_in_bytes(
            int k, const TensorLayout& data, const TensorLayout& values,
            const TensorLayout& indices) = 0;
    void deduce_layout(
            int k, const TensorLayout& data, TensorLayout& values,
            TensorLayout& indices);
};
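
// Usage sketch (illustrative; the mode name is an assumption about the TopK
// param definition): taking the 2 smallest values per row of an (m, n)
// tensor, with sorted values and indices:
//
//     auto opr = handle->create_operator<TopK>();
//     opr->param().mode = TopK::Param::Mode::VALUE_IDX_SORTED;
//     TensorLayout val_layout, idx_layout;
//     opr->deduce_layout(2, data.layout, val_layout, idx_layout);  // (m, 2)
//     opr->exec(2, data, values, indices, workspace);
//     // pass k = -2 instead to get the 2 largest values per row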

/*!
 * \brief convert dtype of *src* to match dtype of *dst*; *src* may have
 *      arbitrary layout and *dst* must be contiguous.
 */
class TypeCvtForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst);
};
using TypeCvt = TypeCvtForward;

class IndexingRemapBase : public OperatorBase {
public:
    using Param = param::IndexingRemap;

    IndexingRemapBase(Handle* handle) : OperatorBase(handle) {}
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;

    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& map,
            const TensorLayout& dst);
};

class IndexingRemapForward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] map input map
     * \param[out] dst output tensor
     *
     * Suppose:
     * the shape of src is \f$(s_0, s_1, ..., s_{m-1})\f$;
     * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$;
     * then:
     * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$.
     *
     * The last dimension of map indicates the src indices for the
     * corresponding dst entry.
     *
     * src and dst can be non-contiguous in a non-overlapping manner.
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& map, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& map,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using IndexingRemap = IndexingRemapForward;
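
// Worked example (illustrative): gathering from a 1-d src (m = 1). With
// src = [10, 11, 12, 13] of shape (4,), dst of shape (2,) and map of shape
// (2, 1) holding [[3], [0]], exec() gives dst = [src[3], src[0]] = [13, 10].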

// The using directives preserve backward compatibility.
using TensorRemapForward = IndexingRemap;
using TensorRemap = TensorRemapForward;

class IndexingRemapBackward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] map the `map' parameter in IndexingRemapForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_in map, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& map,
            const TensorLayout& grad) = 0;

protected:
    void check_exec(
            const TensorLayout& diff, const TensorLayout& map,
            const TensorLayout& grad, size_t workspace_in_bytes);
};
// The using directive preserves backward compatibility.
using TensorRemapBackward = IndexingRemapBackward;

class Linspace : public OperatorBase {
    DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1);
    DEF_OPR_PARAM(LinspaceFull);

public:
    /**
     * \param[out] dst must be 1d.
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html
     */
    virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class Eye : public OperatorBase {
    DEF_OPR_IMPL(Eye, OperatorBase, 0, 1);
    DEF_OPR_PARAM(Eye);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html
     */
    virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class IndexingOneHotBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& index,
            const TensorLayout& dst);
};

/*!
 * \brief Indexing for one-hot encoding
 *
 * Given src, axis and index,
 * for all valid (n-1)-dimensional subscript tuples i iterating through index:
 * dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] =
 *      src[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]
 *
 * \param[in] src n-dimensional input data
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[out] dst n-dimensional output data
 */
class IndexingOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1);

public:
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& index, TensorLayout& dst) {
        deduce_layout_fwd(src, index, dst);
    }
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in index, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& index,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& index,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using IndexingOneHot = IndexingOneHotForward;
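
// Worked example (illustrative): for a 2-d src with axis = 1, dst has shape
// (m, 1) and dst[i][0] = src[i][index[i]]. E.g. with
//
//     src = [[1, 2, 3],
//            [4, 5, 6]]
//     index = [2, 0]
//
// the result is dst = [[3], [4]], picking one entry per row.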

/*!
 * \brief set-subtensor corresponding to IndexingOneHotForward
 *
 * \param[in,out] data n-dimensional input and output data, whose sub part
 *      corresponding to *index* would be replaced by *sub*
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[in] sub n-dimensional sub tensor to be filled in *data*
 */
class IndexingSetOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in index, _megdnn_tensor_in sub,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& data, const TensorLayout& index,
            const TensorLayout& sub) = 0;

protected:
    void check_exec(
            const TensorLayout& data, const TensorLayout& index,
            const TensorLayout& sub, size_t workspace_in_bytes);
};
using IndexingSetOneHot = IndexingSetOneHotForward;

/*!
 * \brief base class for indexing on multiple axes using vector indices
 *
 * Note that the indexing axes are required to be sorted in ascending order.
 */
class IndexingMultiAxisVecBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase);
    DEF_OPR_PARAM(Empty);

public:
    struct AxisIndexer {
        size_t axis;
        TensorND vec;
    };

    struct AxisIndexerLayoutOnly {
        size_t axis;
        TensorLayout layout;
    };

    using IndexDesc = std::vector<AxisIndexer>;
    using IndexDescLayoutOnly = std::vector<AxisIndexerLayoutOnly>;

    /*!
     * \brief convert IndexDesc to IndexDescLayoutOnly
     */
    static IndexDescLayoutOnly extract_index_layout(const IndexDesc& index);

    /*!
     * \brief get the axes on src that are not used in index
     * \param[out] out output buffer; suggested size is
     *      TensorLayout::MAX_NDIM
     * \return number of elements written to *out*
     */
    static size_t get_nonindex_axes(
            size_t src_ndim, const IndexDesc& index, size_t* out);

    /*!
     * \brief get contiguous-collapsed layout for indexing on value
     * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis)
     * \return a tensor layout and an axis to iterate over *value* and also
     *      access *data*; stride of layout on that axis would be zero, and
     *      strides on other axes correspond to the strides in *data*
     */
    static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
            const TensorLayout& data, const TensorLayout& value,
            const IndexDesc& index, size_t idx_axis);

    //! helper info for kernel implementation
    struct ExecInfo {
        //! axis in value used by indexer
        size_t idx_axis;
        ptrdiff_t value_stride;
        void* error_tracker;
        megcore::AsyncErrorInfo* error_info;
    };

protected:
    /*!
     * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis)
     */
    static size_t deduce_layout_fwd(
            const TensorLayout& data, const IndexDescLayoutOnly& index,
            TensorLayout& dst);

    static ExecInfo check_exec_noworkspace(
            const TensorLayout& data, const TensorLayout& value,
            const IndexDesc& index, IndexDescLayoutOnly& index_layout);
};

/*!
 * \brief compute indexing result, like numpy advanced indexing
 *
 * src can have arbitrary layout, but dst must be dim1-contig.
 */
class IndexingMultiAxisVec : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on output shape and indexing axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape& dst, const size_t* axes, size_t nr_axes);

    static void deduce_layout(
            const TensorLayout& data, const IndexDescLayoutOnly& index,
            TensorLayout& dst) {
        deduce_layout_fwd(data, index, dst);
    }

protected:
    virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0;

    ExecInfo check_exec(
            const TensorLayout& src, const IndexDesc& index, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
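
// Usage sketch (illustrative): numpy-style advanced indexing such as
// dst = src[:, idx] for an (m, n) src and an int32 index vector `idx` of
// shape (k,), giving dst of shape (m, k). `handle`, `src`, `idx`, `dst` and
// a suitably sized `workspace` are assumed:
//
//     IndexingMultiAxisVec::IndexDesc desc;
//     desc.push_back({1, idx});  // AxisIndexer{axis, vec}: index axis 1
//     size_t axes[] = {1};
//     auto opr = handle->create_operator<IndexingMultiAxisVec>();
//     size_t ws = opr->get_workspace_in_bytes(dst.layout, axes, 1);
//     opr->exec(src, desc, dst, workspace);  // workspace holds ws bytes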

/*!
 * \brief base class for modifying data by given index
 *
 * data can have arbitrary layout, but value must be dim1-contig.
 */
class IndexingModifyMultiAxisVecBase : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase, IndexingMultiAxisVecBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc& index, _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on shape of value input and indexing
     *      axes
     */
    size_t get_workspace_in_bytes(
            const TensorShape& value, const size_t* axes, size_t nr_axes);

protected:
    ExecInfo check_exec(
            const TensorLayout& data, const TensorLayout& value,
            const IndexDesc& index, size_t workspace_in_bytes);

    virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0;
};

//! set value to indexed locations; index values must be non-overlapping
class IndexingSetMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingSetMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0);
};

//! add value to indexed locations; index values must be non-overlapping
class IndexingIncrMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingIncrMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0);
};

class MeshBase : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL_CTOR(MeshBase, OperatorBase);

public:
    using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer;
    using IndexDesc = IndexingMultiAxisVecBase::IndexDesc;
    using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
    using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;

    size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) {
        return 0;
    }

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc);
};

class NormalMeshBase : public MeshBase {
    DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0);

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class NormalMeshModifyBase : public NormalMeshBase {
    DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class BatchedMeshBase : public MeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase);

protected:
    virtual void check_exec(
            const TensorLayout& origin, const TensorLayout& indexed,
            const IndexDesc& desc) override final;
};

class BatchedMeshModifyBase : public BatchedMeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase);

public:
    virtual void exec(
            _megdnn_tensor_inout data, _megdnn_tensor_in value,
            const IndexDesc& desc, _megdnn_workspace workspace) = 0;
};

class MeshIndexing : public NormalMeshBase {
    DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
            TensorLayout& out_layout);
};

class IncrMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(IncrMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class SetMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(SetMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class BatchedMeshIndexing : public BatchedMeshBase {
    DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0);

public:
    virtual void exec(
            _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
            TensorLayout& out_layout);
};

class BatchedIncrMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedIncrMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class BatchedSetMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedSetMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class RelayoutFormat : public OperatorBase {
    DEF_OPR_PARAM(RelayoutFormat);
    DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    void deduce_format(TensorFormat src, TensorFormat& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
    void deduce_exec_layout(
            const TensorLayout& src, const TensorLayout& dst,
            TensorLayout& exec_src, TensorLayout& exec_dst);
};

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen
