You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opdef.h.inl 95 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914
  1. // clang-format off
  2. class AdaptivePooling : public OpDefImplBase<AdaptivePooling> {
  3. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  4. public:
  5. using Mode = ::megdnn::param::AdaptivePooling::Mode;
  6. using Format = ::megdnn::param::AdaptivePooling::Format;
  7. Mode mode = ::megdnn::param::AdaptivePooling::Mode::MAX;
  8. Format format = ::megdnn::param::AdaptivePooling::Format::NCHW;
  9. std::vector<int32_t> shape;
  10. AdaptivePooling() = default;
  11. AdaptivePooling(Mode mode_, Format format_, std::vector<int32_t> shape_, std::string scope_ = {}): mode(mode_), format(format_), shape(shape_) { set_scope(scope_); }
  12. AdaptivePooling(::megdnn::param::AdaptivePooling packed_param_0, std::vector<int32_t> shape_): mode(packed_param_0.mode), format(packed_param_0.format), shape(shape_) {}
  13. ::megdnn::param::AdaptivePooling param() const {
  14. return {mode, format};
  15. }
  16. };
  17. class AddAxis : public OpDefImplBase<AddAxis> {
  18. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  19. public:
  20. std::vector<int32_t> axis;
  21. AddAxis() = default;
  22. AddAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  23. };
  24. class Argmax : public OpDefImplBase<Argmax> {
  25. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  26. public:
  27. int32_t axis = 0;
  28. Argmax() = default;
  29. Argmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  30. Argmax(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
  31. ::megdnn::param::Axis param() const {
  32. return {axis};
  33. }
  34. };
  35. class Argmin : public OpDefImplBase<Argmin> {
  36. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  37. public:
  38. int32_t axis = 0;
  39. Argmin() = default;
  40. Argmin(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  41. Argmin(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
  42. ::megdnn::param::Axis param() const {
  43. return {axis};
  44. }
  45. };
  46. class Argsort : public OpDefImplBase<Argsort> {
  47. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  48. public:
  49. using Order = ::megdnn::param::Argsort::Order;
  50. Order order = ::megdnn::param::Argsort::Order::ASCENDING;
  51. Argsort() = default;
  52. Argsort(Order order_, std::string scope_ = {}): order(order_) { set_scope(scope_); }
  53. Argsort(::megdnn::param::Argsort packed_param_0): order(packed_param_0.order) {}
  54. ::megdnn::param::Argsort param() const {
  55. return {order};
  56. }
  57. };
  58. class AssertEqual : public OpDefImplBase<AssertEqual> {
  59. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  60. public:
  61. float maxerr = 0.0001;
  62. bool verbose = false;
  63. AssertEqual() = default;
  64. AssertEqual(float maxerr_, bool verbose_, std::string scope_ = {}): maxerr(maxerr_), verbose(verbose_) { set_scope(scope_); }
  65. AssertEqual(::megdnn::param::AssertEqual packed_param_0): maxerr(packed_param_0.maxerr), verbose(packed_param_0.verbose) {}
  66. ::megdnn::param::AssertEqual param() const {
  67. return {maxerr, verbose};
  68. }
  69. };
  70. class AtlasRuntime : public OpDefImplBase<AtlasRuntime> {
  71. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  72. public:
  73. std::string buf;
  74. size_t buf_size;
  75. AtlasRuntime() = default;
  76. AtlasRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
  77. };
  78. class Barrier : public OpDefImplBase<Barrier> {
  79. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  80. public:
  81. ::mgb::CompNode comp_node;
  82. uint32_t nr_outputs;
  83. Barrier() = default;
  84. Barrier(::mgb::CompNode comp_node_, uint32_t nr_outputs_, std::string scope_ = {}): comp_node(comp_node_), nr_outputs(nr_outputs_) { set_scope(scope_); }
  85. };
  86. class BatchConvBias : public OpDefImplBase<BatchConvBias> {
  87. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  88. public:
  89. using NonlineMode = ::megdnn::param::BatchConvBias::NonlineMode;
  90. using Mode = ::megdnn::param::BatchConvBias::Mode;
  91. using Sparse = ::megdnn::param::BatchConvBias::Sparse;
  92. using Format = ::megdnn::param::BatchConvBias::Format;
  93. using ComputeMode = ::megdnn::param::BatchConvBias::ComputeMode;
  94. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  95. NonlineMode nonlineMode = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY;
  96. Mode mode = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION;
  97. uint32_t pad_h = 0;
  98. uint32_t pad_w = 0;
  99. uint32_t stride_h = 1;
  100. uint32_t stride_w = 1;
  101. uint32_t dilate_h = 1;
  102. uint32_t dilate_w = 1;
  103. Sparse sparse = ::megdnn::param::BatchConvBias::Sparse::DENSE;
  104. Format format = ::megdnn::param::BatchConvBias::Format::NCHW;
  105. ComputeMode compute_mode = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT;
  106. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  107. uint64_t workspace_limit = 18446744073709551615ull;
  108. ::megdnn::DType dtype;
  109. BatchConvBias() = default;
  110. BatchConvBias(NonlineMode nonlineMode_, Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
  111. set_scope(scope_);
  112. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  113. }
  114. BatchConvBias(::megdnn::param::BatchConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
  115. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  116. }
  117. ::megdnn::param::BatchConvBias param() const {
  118. return {nonlineMode, mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  119. }
  120. ::megdnn::param::ExecutionPolicy policy() const {
  121. return {strategy, workspace_limit};
  122. }
  123. };
  124. class BatchNorm : public OpDefImplBase<BatchNorm> {
  125. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  126. public:
  127. using ParamDim = ::megdnn::param::BN::ParamDim;
  128. using FwdMode = ::megdnn::param::BN::FwdMode;
  129. ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
  130. FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
  131. double epsilon = 1e-4f;
  132. double avg_factor = 1.f;
  133. float scale = 1.f;
  134. float bias = 0.f;
  135. BatchNorm() = default;
  136. BatchNorm(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
  137. BatchNorm(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
  138. ::megdnn::param::BN param() const {
  139. return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
  140. }
  141. };
  142. class BatchNormBackward : public OpDefImplBase<BatchNormBackward> {
  143. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  144. public:
  145. using ParamDim = ::megdnn::param::BN::ParamDim;
  146. using FwdMode = ::megdnn::param::BN::FwdMode;
  147. ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
  148. FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
  149. double epsilon = 1e-4f;
  150. double avg_factor = 1.f;
  151. float scale = 1.f;
  152. float bias = 0.f;
  153. BatchNormBackward() = default;
  154. BatchNormBackward(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
  155. BatchNormBackward(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
  156. ::megdnn::param::BN param() const {
  157. return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
  158. }
  159. };
  160. class BatchedIncrMeshIndexing : public OpDefImplBase<BatchedIncrMeshIndexing> {
  161. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  162. public:
  163. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  164. BatchedIncrMeshIndexing() = default;
  165. BatchedIncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  166. };
  167. class BatchedMatrixMul : public OpDefImplBase<BatchedMatrixMul> {
  168. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  169. public:
  170. using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
  171. using Format = ::megdnn::param::MatrixMul::Format;
  172. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  173. bool transposeA = false;
  174. bool transposeB = false;
  175. ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
  176. Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
  177. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  178. uint64_t workspace_limit = 18446744073709551615ull;
  179. uint32_t dimA;
  180. uint32_t dimB;
  181. BatchedMatrixMul() = default;
  182. BatchedMatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
  183. set_scope(scope_);
  184. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  185. }
  186. BatchedMatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
  187. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  188. }
  189. ::megdnn::param::MatrixMul param() const {
  190. return {transposeA, transposeB, compute_mode, format};
  191. }
  192. ::megdnn::param::ExecutionPolicy policy() const {
  193. return {strategy, workspace_limit};
  194. }
  195. };
  196. class BatchedMeshIndexing : public OpDefImplBase<BatchedMeshIndexing> {
  197. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  198. public:
  199. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  200. BatchedMeshIndexing() = default;
  201. BatchedMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  202. };
  203. class BatchedSetMeshIndexing : public OpDefImplBase<BatchedSetMeshIndexing> {
  204. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  205. public:
  206. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  207. BatchedSetMeshIndexing() = default;
  208. BatchedSetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  209. };
  210. class BetaRNG : public OpDefImplBase<BetaRNG> {
  211. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  212. public:
  213. uint64_t seed = 0;
  214. size_t handle;
  215. BetaRNG() = default;
  216. BetaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  217. BetaRNG(::megdnn::param::BetaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  218. ::megdnn::param::BetaRNG param() const {
  219. return {seed};
  220. }
  221. };
  222. class Borrow : public OpDefImplBase<Borrow> {
  223. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  224. public:
  225. ::mgb::CompNode comp_node;
  226. Borrow() = default;
  227. Borrow(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
  228. };
  229. class Broadcast : public OpDefImplBase<Broadcast> {
  230. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  231. public:
  232. std::vector<int32_t> shape;
  233. Broadcast() = default;
  234. Broadcast(std::vector<int32_t> shape_, std::string scope_ = {}): shape(shape_) { set_scope(scope_); }
  235. Broadcast(::megdnn::param::Empty, std::vector<int32_t> shape_): shape(shape_) {}
  236. ::megdnn::param::Empty param() const {
  237. return {};
  238. }
  239. };
  240. class CambriconRuntime : public OpDefImplBase<CambriconRuntime> {
  241. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  242. public:
  243. std::string buf;
  244. size_t buf_size;
  245. std::string symbol;
  246. bool tensor_dim_mutable;
  247. CambriconRuntime() = default;
  248. CambriconRuntime(std::string buf_, size_t buf_size_, std::string symbol_, bool tensor_dim_mutable_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_), symbol(symbol_), tensor_dim_mutable(tensor_dim_mutable_) { set_scope(scope_); }
  249. };
  250. class CheckNonFinite : public OpDefImplBase<CheckNonFinite> {
  251. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  252. public:
  253. float scale = 1.0;
  254. CheckNonFinite() = default;
  255. CheckNonFinite(float scale_, std::string scope_ = {}): scale(scale_) { set_scope(scope_); }
  256. CheckNonFinite(::megdnn::param::CheckNonFinite packed_param_0): scale(packed_param_0.scale) {}
  257. ::megdnn::param::CheckNonFinite param() const {
  258. return {scale};
  259. }
  260. };
  261. class CollectiveComm : public OpDefImplBase<CollectiveComm> {
  262. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  263. public:
  264. using Mode = ::megdnn::param::CollectiveComm::Mode;
  265. Mode mode = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM;
  266. std::string key;
  267. uint32_t nr_devices;
  268. uint32_t rank;
  269. bool is_root;
  270. bool local_grad;
  271. std::string addr;
  272. uint32_t port;
  273. ::megdnn::DType dtype;
  274. std::string backend;
  275. std::string comp_node;
  276. CollectiveComm() = default;
  277. CollectiveComm(Mode mode_, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_, std::string scope_ = {}): mode(mode_), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) { set_scope(scope_); }
  278. CollectiveComm(::megdnn::param::CollectiveComm packed_param_0, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_): mode(packed_param_0.mode), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) {}
  279. ::megdnn::param::CollectiveComm param() const {
  280. return {mode};
  281. }
  282. };
  283. class Concat : public OpDefImplBase<Concat> {
  284. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  285. public:
  286. int32_t axis = 0;
  287. ::mgb::CompNode comp_node;
  288. Concat() = default;
  289. Concat(int32_t axis_, ::mgb::CompNode comp_node_, std::string scope_ = {}): axis(axis_), comp_node(comp_node_) { set_scope(scope_); }
  290. Concat(::megdnn::param::Axis packed_param_0, ::mgb::CompNode comp_node_): axis(packed_param_0.axis), comp_node(comp_node_) {}
  291. ::megdnn::param::Axis param() const {
  292. return {axis};
  293. }
  294. };
  295. class CondTake : public OpDefImplBase<CondTake> {
  296. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  297. public:
  298. CondTake() = default;
  299. };
  300. class ConvBias : public OpDefImplBase<ConvBias> {
  301. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  302. public:
  303. using NonlineMode = ::megdnn::param::ConvBias::NonlineMode;
  304. using Mode = ::megdnn::param::ConvBias::Mode;
  305. using Sparse = ::megdnn::param::ConvBias::Sparse;
  306. using Format = ::megdnn::param::ConvBias::Format;
  307. using ComputeMode = ::megdnn::param::ConvBias::ComputeMode;
  308. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  309. NonlineMode nonlineMode = ::megdnn::param::ConvBias::NonlineMode::IDENTITY;
  310. Mode mode = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION;
  311. Sparse sparse = ::megdnn::param::ConvBias::Sparse::DENSE;
  312. Format format = ::megdnn::param::ConvBias::Format::NCHW;
  313. uint32_t pad_h = 0;
  314. uint32_t pad_w = 0;
  315. uint32_t stride_h = 1;
  316. uint32_t stride_w = 1;
  317. uint32_t dilate_h = 1;
  318. uint32_t dilate_w = 1;
  319. ComputeMode compute_mode = ::megdnn::param::ConvBias::ComputeMode::DEFAULT;
  320. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  321. uint64_t workspace_limit = 18446744073709551615ull;
  322. ::megdnn::DType dtype;
  323. ConvBias() = default;
  324. ConvBias(NonlineMode nonlineMode_, Mode mode_, Sparse sparse_, Format format_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), sparse(sparse_), format(format_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
  325. set_scope(scope_);
  326. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  327. }
  328. ConvBias(::megdnn::param::ConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), sparse(packed_param_0.sparse), format(packed_param_0.format), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
  329. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  330. }
  331. ::megdnn::param::ConvBias param() const {
  332. return {nonlineMode, mode, sparse, format, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, compute_mode};
  333. }
  334. ::megdnn::param::ExecutionPolicy policy() const {
  335. return {strategy, workspace_limit};
  336. }
  337. };
  338. class Convolution : public OpDefImplBase<Convolution> {
  339. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  340. public:
  341. using Mode = ::megdnn::param::Convolution::Mode;
  342. using Sparse = ::megdnn::param::Convolution::Sparse;
  343. using Format = ::megdnn::param::Convolution::Format;
  344. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  345. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  346. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  347. uint32_t pad_h = 0;
  348. uint32_t pad_w = 0;
  349. uint32_t stride_h = 1;
  350. uint32_t stride_w = 1;
  351. uint32_t dilate_h = 1;
  352. uint32_t dilate_w = 1;
  353. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  354. Format format = ::megdnn::param::Convolution::Format::NCHW;
  355. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  356. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  357. uint64_t workspace_limit = 18446744073709551615ull;
  358. Convolution() = default;
  359. Convolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
  360. set_scope(scope_);
  361. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  362. }
  363. Convolution(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  364. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  365. }
  366. ::megdnn::param::Convolution param() const {
  367. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  368. }
  369. ::megdnn::param::ExecutionPolicy policy() const {
  370. return {strategy, workspace_limit};
  371. }
  372. };
  373. class Convolution3D : public OpDefImplBase<Convolution3D> {
  374. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  375. public:
  376. using Mode = ::megdnn::param::Convolution3D::Mode;
  377. using Sparse = ::megdnn::param::Convolution3D::Sparse;
  378. using DataType = ::megdnn::param::Convolution3D::DataType;
  379. using Format = ::megdnn::param::Convolution3D::Format;
  380. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  381. Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
  382. uint32_t pad_d = 0;
  383. uint32_t pad_h = 0;
  384. uint32_t pad_w = 0;
  385. uint32_t stride_d = 1;
  386. uint32_t stride_h = 1;
  387. uint32_t stride_w = 1;
  388. uint32_t dilate_d = 1;
  389. uint32_t dilate_h = 1;
  390. uint32_t dilate_w = 1;
  391. Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
  392. DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
  393. Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
  394. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  395. uint64_t workspace_limit = 18446744073709551615ull;
  396. Convolution3D() = default;
  397. Convolution3D(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
  398. set_scope(scope_);
  399. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  400. }
  401. Convolution3D(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  402. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  403. }
  404. ::megdnn::param::Convolution3D param() const {
  405. return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
  406. }
  407. ::megdnn::param::ExecutionPolicy policy() const {
  408. return {strategy, workspace_limit};
  409. }
  410. };
  411. class Convolution3DBackwardData : public OpDefImplBase<Convolution3DBackwardData> {
  412. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  413. public:
  414. using Mode = ::megdnn::param::Convolution3D::Mode;
  415. using Sparse = ::megdnn::param::Convolution3D::Sparse;
  416. using DataType = ::megdnn::param::Convolution3D::DataType;
  417. using Format = ::megdnn::param::Convolution3D::Format;
  418. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  419. Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
  420. uint32_t pad_d = 0;
  421. uint32_t pad_h = 0;
  422. uint32_t pad_w = 0;
  423. uint32_t stride_d = 1;
  424. uint32_t stride_h = 1;
  425. uint32_t stride_w = 1;
  426. uint32_t dilate_d = 1;
  427. uint32_t dilate_h = 1;
  428. uint32_t dilate_w = 1;
  429. Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
  430. DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
  431. Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
  432. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  433. uint64_t workspace_limit = 18446744073709551615ull;
  434. Convolution3DBackwardData() = default;
  435. Convolution3DBackwardData(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
  436. set_scope(scope_);
  437. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  438. }
  439. Convolution3DBackwardData(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  440. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  441. }
  442. ::megdnn::param::Convolution3D param() const {
  443. return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
  444. }
  445. ::megdnn::param::ExecutionPolicy policy() const {
  446. return {strategy, workspace_limit};
  447. }
  448. };
  449. class ConvolutionBackwardData : public OpDefImplBase<ConvolutionBackwardData> {
  450. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  451. public:
  452. using Mode = ::megdnn::param::Convolution::Mode;
  453. using Sparse = ::megdnn::param::Convolution::Sparse;
  454. using Format = ::megdnn::param::Convolution::Format;
  455. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  456. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  457. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  458. uint32_t pad_h = 0;
  459. uint32_t pad_w = 0;
  460. uint32_t stride_h = 1;
  461. uint32_t stride_w = 1;
  462. uint32_t dilate_h = 1;
  463. uint32_t dilate_w = 1;
  464. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  465. Format format = ::megdnn::param::Convolution::Format::NCHW;
  466. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  467. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  468. uint64_t workspace_limit = 18446744073709551615ull;
  469. ::megdnn::DType dtype;
  470. ConvolutionBackwardData() = default;
  471. ConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
  472. set_scope(scope_);
  473. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  474. }
  475. ConvolutionBackwardData(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
  476. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  477. }
  478. ::megdnn::param::Convolution param() const {
  479. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  480. }
  481. ::megdnn::param::ExecutionPolicy policy() const {
  482. return {strategy, workspace_limit};
  483. }
  484. };
  485. class Copy : public OpDefImplBase<Copy> {
  486. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  487. public:
  488. ::mgb::CompNode comp_node;
  489. Copy() = default;
  490. Copy(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
  491. };
  492. class Correlation : public OpDefImplBase<Correlation> {
  493. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  494. public:
  495. using Format = ::megdnn::param::Correlation::Format;
  496. Format format = ::megdnn::param::Correlation::Format::NCHW;
  497. uint32_t kernel_size = 1;
  498. uint32_t max_displacement = 1;
  499. uint32_t stride1 = 1;
  500. uint32_t stride2 = 1;
  501. uint32_t pad_size = 0;
  502. bool is_multiply = true;
  503. Correlation() = default;
  504. Correlation(Format format_, uint32_t kernel_size_, uint32_t max_displacement_, uint32_t stride1_, uint32_t stride2_, uint32_t pad_size_, bool is_multiply_, std::string scope_ = {}): format(format_), kernel_size(kernel_size_), max_displacement(max_displacement_), stride1(stride1_), stride2(stride2_), pad_size(pad_size_), is_multiply(is_multiply_) { set_scope(scope_); }
  505. Correlation(::megdnn::param::Correlation packed_param_0): format(packed_param_0.format), kernel_size(packed_param_0.kernel_size), max_displacement(packed_param_0.max_displacement), stride1(packed_param_0.stride1), stride2(packed_param_0.stride2), pad_size(packed_param_0.pad_size), is_multiply(packed_param_0.is_multiply) {}
  506. ::megdnn::param::Correlation param() const {
  507. return {format, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply};
  508. }
  509. };
  510. class Cumsum : public OpDefImplBase<Cumsum> {
  511. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  512. public:
  513. int32_t axis = 2147483647;
  514. bool exclusive = true;
  515. bool reverse = false;
  516. Cumsum() = default;
  517. Cumsum(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); }
  518. Cumsum(::megdnn::param::Cumsum packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {}
  519. ::megdnn::param::Cumsum param() const {
  520. return {axis, exclusive, reverse};
  521. }
  522. };
  523. class CvtColor : public OpDefImplBase<CvtColor> {
  524. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  525. public:
  526. using Mode = ::megdnn::param::CvtColor::Mode;
  527. Mode mode = ::megdnn::param::CvtColor::Mode::RGB2GRAY;
  528. CvtColor() = default;
  529. CvtColor(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  530. CvtColor(::megdnn::param::CvtColor packed_param_0): mode(packed_param_0.mode) {}
  531. ::megdnn::param::CvtColor param() const {
  532. return {mode};
  533. }
  534. };
  535. class DeformableConv : public OpDefImplBase<DeformableConv> {
  536. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  537. public:
  538. using Mode = ::megdnn::param::Convolution::Mode;
  539. using Sparse = ::megdnn::param::Convolution::Sparse;
  540. using Format = ::megdnn::param::Convolution::Format;
  541. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  542. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  543. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  544. uint32_t pad_h = 0;
  545. uint32_t pad_w = 0;
  546. uint32_t stride_h = 1;
  547. uint32_t stride_w = 1;
  548. uint32_t dilate_h = 1;
  549. uint32_t dilate_w = 1;
  550. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  551. Format format = ::megdnn::param::Convolution::Format::NCHW;
  552. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  553. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  554. uint64_t workspace_limit = 18446744073709551615ull;
  555. DeformableConv() = default;
  556. DeformableConv(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
  557. set_scope(scope_);
  558. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  559. }
  560. DeformableConv(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  561. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  562. }
  563. ::megdnn::param::Convolution param() const {
  564. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  565. }
  566. ::megdnn::param::ExecutionPolicy policy() const {
  567. return {strategy, workspace_limit};
  568. }
  569. };
  570. class DeformablePSROIPooling : public OpDefImplBase<DeformablePSROIPooling> {
  571. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  572. public:
  573. bool no_trans = true;
  574. float spatial_scale = 1;
  575. float trans_std = 1;
  576. uint32_t pooled_h = 1;
  577. uint32_t pooled_w = 1;
  578. uint32_t part_size = 1;
  579. uint32_t sample_per_part = 1;
  580. DeformablePSROIPooling() = default;
  581. DeformablePSROIPooling(bool no_trans_, float spatial_scale_, float trans_std_, uint32_t pooled_h_, uint32_t pooled_w_, uint32_t part_size_, uint32_t sample_per_part_, std::string scope_ = {}): no_trans(no_trans_), spatial_scale(spatial_scale_), trans_std(trans_std_), pooled_h(pooled_h_), pooled_w(pooled_w_), part_size(part_size_), sample_per_part(sample_per_part_) { set_scope(scope_); }
  582. DeformablePSROIPooling(::megdnn::param::DeformablePSROIPooling packed_param_0): no_trans(packed_param_0.no_trans), spatial_scale(packed_param_0.spatial_scale), trans_std(packed_param_0.trans_std), pooled_h(packed_param_0.pooled_h), pooled_w(packed_param_0.pooled_w), part_size(packed_param_0.part_size), sample_per_part(packed_param_0.sample_per_part) {}
  583. ::megdnn::param::DeformablePSROIPooling param() const {
  584. return {no_trans, spatial_scale, trans_std, pooled_h, pooled_w, part_size, sample_per_part};
  585. }
  586. };
  587. class Diag : public OpDefImplBase<Diag> {
  588. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  589. public:
  590. int32_t k = 0;
  591. Diag() = default;
  592. Diag(int32_t k_, std::string scope_ = {}): k(k_) { set_scope(scope_); }
  593. Diag(::megdnn::param::Diag packed_param_0): k(packed_param_0.k) {}
  594. ::megdnn::param::Diag param() const {
  595. return {k};
  596. }
  597. };
  598. class Dimshuffle : public OpDefImplBase<Dimshuffle> {
  599. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  600. public:
  601. std::vector<int32_t> pattern;
  602. Dimshuffle() = default;
  603. Dimshuffle(std::vector<int32_t> pattern_, std::string scope_ = {}): pattern(pattern_) { set_scope(scope_); }
  604. };
  605. class Dot : public OpDefImplBase<Dot> {
  606. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  607. public:
  608. Dot() = default;
  609. Dot(::megdnn::param::Empty) {}
  610. ::megdnn::param::Empty param() const {
  611. return {};
  612. }
  613. };
  614. class Dropout : public OpDefImplBase<Dropout> {
  615. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  616. public:
  617. float drop_prob = 0;
  618. uint64_t seed = 0;
  619. size_t handle;
  620. Dropout() = default;
  621. Dropout(float drop_prob_, uint64_t seed_, size_t handle_, std::string scope_ = {}): drop_prob(drop_prob_), seed(seed_), handle(handle_) { set_scope(scope_); }
  622. Dropout(::megdnn::param::Dropout packed_param_0, size_t handle_): drop_prob(packed_param_0.drop_prob), seed(packed_param_0.seed), handle(handle_) {}
  623. ::megdnn::param::Dropout param() const {
  624. return {drop_prob, seed};
  625. }
  626. };
  627. class Elemwise : public OpDefImplBase<Elemwise> {
  628. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  629. public:
  630. using Mode = ::megdnn::param::Elemwise::Mode;
  631. Mode mode = ::megdnn::param::Elemwise::Mode::RELU;
  632. Elemwise() = default;
  633. Elemwise(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  634. Elemwise(::megdnn::param::Elemwise packed_param_0): mode(packed_param_0.mode) {}
  635. ::megdnn::param::Elemwise param() const {
  636. return {mode};
  637. }
  638. };
// Maps each Elemwise::Mode enumerator to its unqualified name; any value not
// listed below falls through to the default and yields
// "Elemwise::Mode::Unknown".
template <>
struct ToStringTrait<Elemwise::Mode> {
std::string operator()(Elemwise::Mode e) const {
switch (e) {
case Elemwise::Mode::RELU: return "RELU";
case Elemwise::Mode::ABS: return "ABS";
case Elemwise::Mode::ACOS: return "ACOS";
case Elemwise::Mode::ASIN: return "ASIN";
case Elemwise::Mode::CEIL: return "CEIL";
case Elemwise::Mode::COS: return "COS";
case Elemwise::Mode::EXP: return "EXP";
case Elemwise::Mode::EXPM1: return "EXPM1";
case Elemwise::Mode::FLOOR: return "FLOOR";
case Elemwise::Mode::LOG: return "LOG";
case Elemwise::Mode::LOG1P: return "LOG1P";
case Elemwise::Mode::NEGATE: return "NEGATE";
case Elemwise::Mode::SIGMOID: return "SIGMOID";
case Elemwise::Mode::SIN: return "SIN";
case Elemwise::Mode::TANH: return "TANH";
case Elemwise::Mode::ABS_GRAD: return "ABS_GRAD";
case Elemwise::Mode::ADD: return "ADD";
case Elemwise::Mode::FLOOR_DIV: return "FLOOR_DIV";
case Elemwise::Mode::MAX: return "MAX";
case Elemwise::Mode::MIN: return "MIN";
case Elemwise::Mode::MOD: return "MOD";
case Elemwise::Mode::MUL: return "MUL";
case Elemwise::Mode::POW: return "POW";
case Elemwise::Mode::SIGMOID_GRAD: return "SIGMOID_GRAD";
case Elemwise::Mode::SUB: return "SUB";
case Elemwise::Mode::SWITCH_GT0: return "SWITCH_GT0";
case Elemwise::Mode::TANH_GRAD: return "TANH_GRAD";
case Elemwise::Mode::TRUE_DIV: return "TRUE_DIV";
case Elemwise::Mode::LOG_SUM_EXP: return "LOG_SUM_EXP";
case Elemwise::Mode::LT: return "LT";
case Elemwise::Mode::LEQ: return "LEQ";
case Elemwise::Mode::EQ: return "EQ";
case Elemwise::Mode::SHL: return "SHL";
case Elemwise::Mode::SHR: return "SHR";
case Elemwise::Mode::COND_LEQ_MOV: return "COND_LEQ_MOV";
case Elemwise::Mode::FUSE_MUL_ADD3: return "FUSE_MUL_ADD3";
case Elemwise::Mode::FUSE_MUL_ADD4: return "FUSE_MUL_ADD4";
case Elemwise::Mode::FUSE_ADD_RELU: return "FUSE_ADD_RELU";
case Elemwise::Mode::FUSE_ADD_SIGMOID: return "FUSE_ADD_SIGMOID";
case Elemwise::Mode::FUSE_ADD_TANH: return "FUSE_ADD_TANH";
case Elemwise::Mode::FAST_TANH: return "FAST_TANH";
case Elemwise::Mode::FAST_TANH_GRAD: return "FAST_TANH_GRAD";
case Elemwise::Mode::ROUND: return "ROUND";
case Elemwise::Mode::RMULH: return "RMULH";
case Elemwise::Mode::ATAN2: return "ATAN2";
case Elemwise::Mode::ERF: return "ERF";
case Elemwise::Mode::ERFINV: return "ERFINV";
case Elemwise::Mode::ERFC: return "ERFC";
case Elemwise::Mode::ERFCINV: return "ERFCINV";
case Elemwise::Mode::H_SWISH: return "H_SWISH";
case Elemwise::Mode::H_SWISH_GRAD: return "H_SWISH_GRAD";
case Elemwise::Mode::FUSE_ADD_H_SWISH: return "FUSE_ADD_H_SWISH";
case Elemwise::Mode::NOT: return "NOT";
case Elemwise::Mode::AND: return "AND";
case Elemwise::Mode::OR: return "OR";
case Elemwise::Mode::XOR: return "XOR";
case Elemwise::Mode::SILU: return "SILU";
case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD";
case Elemwise::Mode::GELU: return "GELU";
case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD";
case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV";
case Elemwise::Mode::NEQ: return "NEQ";
case Elemwise::Mode::ISNAN: return "ISNAN";
case Elemwise::Mode::ISINF: return "ISINF";
// Fallback for enumerators added after this table was generated.
default:
return "Elemwise::Mode::Unknown";
}
}
};
  712. class ElemwiseMultiType : public OpDefImplBase<ElemwiseMultiType> {
  713. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  714. public:
  715. using Mode = ::megdnn::param::ElemwiseMultiType::Mode;
  716. Mode mode = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
  717. ::megdnn::DType dtype;
  718. ElemwiseMultiType() = default;
  719. ElemwiseMultiType(Mode mode_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), dtype(dtype_) { set_scope(scope_); }
  720. ElemwiseMultiType(::megdnn::param::ElemwiseMultiType packed_param_0, ::megdnn::DType dtype_): mode(packed_param_0.mode), dtype(dtype_) {}
  721. ::megdnn::param::ElemwiseMultiType param() const {
  722. return {mode};
  723. }
  724. };
// Maps each ElemwiseMultiType::Mode enumerator to its unqualified name; any
// value not listed below falls through to the default and yields
// "ElemwiseMultiType::Mode::Unknown".
template <>
struct ToStringTrait<ElemwiseMultiType::Mode> {
std::string operator()(ElemwiseMultiType::Mode e) const {
switch (e) {
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32: return "FUSE_MUL_ADD3_INT16x32x32x32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: return "FUSE_MUL_ADD3_IXxF32xF32xI8";
case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8: return "ROUND_SHR_SATURATE_IXxI8xI8";
case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8";
case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8";
case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16: return "ROUND_SHR_SATURATE_IXxI8xI16";
case ElemwiseMultiType::Mode::QADD: return "QADD";
case ElemwiseMultiType::Mode::QFUSE_ADD_RELU: return "QFUSE_ADD_RELU";
case ElemwiseMultiType::Mode::QMUL: return "QMUL";
case ElemwiseMultiType::Mode::QMIN: return "QMIN";
case ElemwiseMultiType::Mode::QMAX: return "QMAX";
case ElemwiseMultiType::Mode::QSUB: return "QSUB";
case ElemwiseMultiType::Mode::QTRUE_DIV: return "QTRUE_DIV";
case ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID: return "QFUSE_ADD_SIGMOID";
case ElemwiseMultiType::Mode::QFUSE_ADD_TANH: return "QFUSE_ADD_TANH";
case ElemwiseMultiType::Mode::QRELU: return "QRELU";
case ElemwiseMultiType::Mode::QABS: return "QABS";
case ElemwiseMultiType::Mode::QSIGMOID: return "QSIGMOID";
case ElemwiseMultiType::Mode::QEXP: return "QEXP";
case ElemwiseMultiType::Mode::QTANH: return "QTANH";
case ElemwiseMultiType::Mode::QFUSE_MUL_ADD3: return "QFUSE_MUL_ADD3";
case ElemwiseMultiType::Mode::QFAST_TANH: return "QFAST_TANH";
case ElemwiseMultiType::Mode::QNEGATE: return "QNEGATE";
case ElemwiseMultiType::Mode::QACOS: return "QACOS";
case ElemwiseMultiType::Mode::QASIN: return "QASIN";
case ElemwiseMultiType::Mode::QCEIL: return "QCEIL";
case ElemwiseMultiType::Mode::QCOS: return "QCOS";
case ElemwiseMultiType::Mode::QEXPM1: return "QEXPM1";
case ElemwiseMultiType::Mode::QFLOOR: return "QFLOOR";
case ElemwiseMultiType::Mode::QLOG: return "QLOG";
case ElemwiseMultiType::Mode::QLOG1P: return "QLOG1P";
case ElemwiseMultiType::Mode::QSIN: return "QSIN";
case ElemwiseMultiType::Mode::QROUND: return "QROUND";
case ElemwiseMultiType::Mode::QERF: return "QERF";
case ElemwiseMultiType::Mode::QERFINV: return "QERFINV";
case ElemwiseMultiType::Mode::QERFC: return "QERFC";
case ElemwiseMultiType::Mode::QERFCINV: return "QERFCINV";
case ElemwiseMultiType::Mode::QABS_GRAD: return "QABS_GRAD";
case ElemwiseMultiType::Mode::QFLOOR_DIV: return "QFLOOR_DIV";
case ElemwiseMultiType::Mode::QMOD: return "QMOD";
case ElemwiseMultiType::Mode::QSIGMOID_GRAD: return "QSIGMOID_GRAD";
case ElemwiseMultiType::Mode::QSWITCH_GT0: return "QSWITCH_GT0";
case ElemwiseMultiType::Mode::QTANH_GRAD: return "QTANH_GRAD";
case ElemwiseMultiType::Mode::QLT: return "QLT";
case ElemwiseMultiType::Mode::QLEQ: return "QLEQ";
case ElemwiseMultiType::Mode::QEQ: return "QEQ";
case ElemwiseMultiType::Mode::QPOW: return "QPOW";
case ElemwiseMultiType::Mode::QLOG_SUM_EXP: return "QLOG_SUM_EXP";
case ElemwiseMultiType::Mode::QFAST_TANH_GRAD: return "QFAST_TANH_GRAD";
case ElemwiseMultiType::Mode::QATAN2: return "QATAN2";
case ElemwiseMultiType::Mode::QCOND_LEQ_MOV: return "QCOND_LEQ_MOV";
case ElemwiseMultiType::Mode::QH_SWISH: return "QH_SWISH";
case ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH: return "QFUSE_ADD_H_SWISH";
case ElemwiseMultiType::Mode::QH_SWISH_GRAD: return "QH_SWISH_GRAD";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_ADD3_INT16xF32xF32xF32";
case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32";
case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV";
case ElemwiseMultiType::Mode::EQ: return "EQ";
case ElemwiseMultiType::Mode::NEQ: return "NEQ";
case ElemwiseMultiType::Mode::LT: return "LT";
case ElemwiseMultiType::Mode::LEQ: return "LEQ";
case ElemwiseMultiType::Mode::ISNAN: return "ISNAN";
case ElemwiseMultiType::Mode::ISINF: return "ISINF";
// Fallback for enumerators added after this table was generated.
default:
return "ElemwiseMultiType::Mode::Unknown";
}
}
};
  798. class ExternOpr : public OpDefImplBase<ExternOpr> {
  799. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  800. public:
  801. std::vector<std::vector<size_t>> output_shapes;
  802. std::string name;
  803. std::string data;
  804. size_t data_len;
  805. std::vector<::megdnn::DType> output_dtypes;
  806. ExternOpr() = default;
  807. ExternOpr(std::vector<std::vector<size_t>> output_shapes_, std::string name_, std::string data_, size_t data_len_, std::vector<::megdnn::DType> output_dtypes_, std::string scope_ = {}): output_shapes(output_shapes_), name(name_), data(data_), data_len(data_len_), output_dtypes(output_dtypes_) { set_scope(scope_); }
  808. };
  809. class Eye : public OpDefImplBase<Eye> {
  810. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  811. public:
  812. int32_t k = 0;
  813. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  814. ::mgb::CompNode comp_node;
  815. Eye() = default;
  816. Eye(int32_t k_, ::megdnn::DType dtype_, ::mgb::CompNode comp_node_, std::string scope_ = {}): k(k_), dtype(dtype_), comp_node(comp_node_) { set_scope(scope_); }
  817. };
  818. class FakeQuant : public OpDefImplBase<FakeQuant> {
  819. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  820. public:
  821. int32_t qmin = -2147483648;
  822. int32_t qmax = 2147483647;
  823. FakeQuant() = default;
  824. FakeQuant(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  825. FakeQuant(::megdnn::param::FakeQuant packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  826. ::megdnn::param::FakeQuant param() const {
  827. return {qmin, qmax};
  828. }
  829. };
// Parameter-less opr definition: no fields beyond the dynamic-type
// registration; only a defaulted constructor.
class FastpathCopy : public OpDefImplBase<FastpathCopy> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
FastpathCopy() = default;
};
  835. class GammaRNG : public OpDefImplBase<GammaRNG> {
  836. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  837. public:
  838. uint64_t seed = 0;
  839. size_t handle;
  840. GammaRNG() = default;
  841. GammaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  842. GammaRNG(::megdnn::param::GammaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  843. ::megdnn::param::GammaRNG param() const {
  844. return {seed};
  845. }
  846. };
  847. class GaussianRNG : public OpDefImplBase<GaussianRNG> {
  848. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  849. public:
  850. uint64_t seed = 0;
  851. float mean = 0;
  852. float std = 1;
  853. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  854. size_t handle;
  855. GaussianRNG() = default;
  856. GaussianRNG(uint64_t seed_, float mean_, float std_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), mean(mean_), std(std_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
  857. };
  858. class GetVarShape : public OpDefImplBase<GetVarShape> {
  859. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  860. public:
  861. int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
  862. GetVarShape() = default;
  863. GetVarShape(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  864. GetVarShape(::megdnn::param::OptionalAxisV1 packed_param_0): axis(packed_param_0.axis) {}
  865. ::megdnn::param::OptionalAxisV1 param() const {
  866. return {axis};
  867. }
  868. };
  869. class GroupLocal : public OpDefImplBase<GroupLocal> {
  870. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  871. public:
  872. using Mode = ::megdnn::param::Convolution::Mode;
  873. using Sparse = ::megdnn::param::Convolution::Sparse;
  874. using Format = ::megdnn::param::Convolution::Format;
  875. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  876. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  877. uint32_t pad_h = 0;
  878. uint32_t pad_w = 0;
  879. uint32_t stride_h = 1;
  880. uint32_t stride_w = 1;
  881. uint32_t dilate_h = 1;
  882. uint32_t dilate_w = 1;
  883. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  884. Format format = ::megdnn::param::Convolution::Format::NCHW;
  885. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  886. GroupLocal() = default;
  887. GroupLocal(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
  888. GroupLocal(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
  889. ::megdnn::param::Convolution param() const {
  890. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  891. }
  892. };
  893. class GroupNorm : public OpDefImplBase<GroupNorm> {
  894. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  895. public:
  896. using Format = ::megdnn::param::GroupNorm::Format;
  897. bool affine = true;
  898. float eps = 1e-5f;
  899. uint32_t group = 1;
  900. Format format = ::megdnn::param::GroupNorm::Format::NCHW;
  901. GroupNorm() = default;
  902. GroupNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); }
  903. GroupNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {}
  904. ::megdnn::param::GroupNorm param() const {
  905. return {affine, eps, group, format};
  906. }
  907. };
// Parameter-less opr definition: no fields beyond the dynamic-type
// registration; only a defaulted constructor.
class Identity : public OpDefImplBase<Identity> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
Identity() = default;
};
  913. class Images2Neibs : public OpDefImplBase<Images2Neibs> {
  914. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  915. public:
  916. uint32_t pad_h = 0;
  917. uint32_t pad_w = 0;
  918. uint32_t stride_h = 1;
  919. uint32_t stride_w = 1;
  920. uint32_t dilate_h = 1;
  921. uint32_t dilate_w = 1;
  922. uint32_t window_h = 3;
  923. uint32_t window_w = 3;
  924. Images2Neibs() = default;
  925. Images2Neibs(uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
  926. Images2Neibs(::megdnn::param::Images2Neibs packed_param_0): pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
  927. ::megdnn::param::Images2Neibs param() const {
  928. return {pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
  929. }
  930. };
  931. class IncrMeshIndexing : public OpDefImplBase<IncrMeshIndexing> {
  932. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  933. public:
  934. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  935. IncrMeshIndexing() = default;
  936. IncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  937. };
  938. class IncrSubtensor : public OpDefImplBase<IncrSubtensor> {
  939. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  940. public:
  941. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  942. IncrSubtensor() = default;
  943. IncrSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  944. };
// IndexingIncrMultiAxisVec: auto-generated opdef carrying only the per-axis
// index descriptors; no packed megdnn param.
class IndexingIncrMultiAxisVec : public OpDefImplBase<IndexingIncrMultiAxisVec> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // One descriptor per indexed axis: (axis, flag, flag, flag, flag).
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
    IndexingIncrMultiAxisVec() = default;
    IndexingIncrMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
// IndexingMultiAxisVec: auto-generated opdef carrying only the per-axis
// index descriptors; no packed megdnn param.
class IndexingMultiAxisVec : public OpDefImplBase<IndexingMultiAxisVec> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // One descriptor per indexed axis: (axis, flag, flag, flag, flag).
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
    IndexingMultiAxisVec() = default;
    IndexingMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
// IndexingOneHot: auto-generated opdef wrapping ::megdnn::param::Axis plus an
// opdef-local `ndim` that is NOT part of the packed param (param() returns
// only {axis}).
class IndexingOneHot : public OpDefImplBase<IndexingOneHot> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    int32_t axis = 0;
    // No default: callers must supply ndim explicitly.
    int32_t ndim;
    IndexingOneHot() = default;
    IndexingOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
    IndexingOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
    // Repack the megdnn-visible fields only.
    ::megdnn::param::Axis param() const {
        return {axis};
    }
};
// IndexingSetMultiAxisVec: auto-generated opdef carrying only the per-axis
// index descriptors; no packed megdnn param.
class IndexingSetMultiAxisVec : public OpDefImplBase<IndexingSetMultiAxisVec> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // One descriptor per indexed axis: (axis, flag, flag, flag, flag).
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
    IndexingSetMultiAxisVec() = default;
    IndexingSetMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
// IndexingSetOneHot: auto-generated opdef wrapping ::megdnn::param::Axis plus
// an opdef-local `ndim` that is NOT part of the packed param (param() returns
// only {axis}). Mirrors IndexingOneHot.
class IndexingSetOneHot : public OpDefImplBase<IndexingSetOneHot> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    int32_t axis = 0;
    // No default: callers must supply ndim explicitly.
    int32_t ndim;
    IndexingSetOneHot() = default;
    IndexingSetOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
    IndexingSetOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
    ::megdnn::param::Axis param() const {
        return {axis};
    }
};
// InplaceAdd: auto-generated opdef with no parameters; the packed param type
// is ::megdnn::param::Empty.
class InplaceAdd : public OpDefImplBase<InplaceAdd> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    InplaceAdd() = default;
    // Accepts (and discards) the empty packed param for interface uniformity.
    InplaceAdd(::megdnn::param::Empty) {}
    ::megdnn::param::Empty param() const {
        return {};
    }
};
// LAMBUpdate: auto-generated opdef mirroring ::megdnn::param::LAMBUpdate
// field-for-field; param() repacks the fields in declaration order.
class LAMBUpdate : public OpDefImplBase<LAMBUpdate> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    float beta_1 = 1.f;
    float beta_2 = 1.f;
    float step = 1.f;
    float lr = 1.f;
    float weight_decay = 1.f;
    float eps = 1.f;
    bool bias_correction = true;
    bool always_adapt = false;
    LAMBUpdate() = default;
    LAMBUpdate(float beta_1_, float beta_2_, float step_, float lr_, float weight_decay_, float eps_, bool bias_correction_, bool always_adapt_, std::string scope_ = {}): beta_1(beta_1_), beta_2(beta_2_), step(step_), lr(lr_), weight_decay(weight_decay_), eps(eps_), bias_correction(bias_correction_), always_adapt(always_adapt_) { set_scope(scope_); }
    LAMBUpdate(::megdnn::param::LAMBUpdate packed_param_0): beta_1(packed_param_0.beta_1), beta_2(packed_param_0.beta_2), step(packed_param_0.step), lr(packed_param_0.lr), weight_decay(packed_param_0.weight_decay), eps(packed_param_0.eps), bias_correction(packed_param_0.bias_correction), always_adapt(packed_param_0.always_adapt) {}
    ::megdnn::param::LAMBUpdate param() const {
        return {beta_1, beta_2, step, lr, weight_decay, eps, bias_correction, always_adapt};
    }
};
// LRN: auto-generated opdef mirroring ::megdnn::param::LRN
// (local response normalization: window size n, bias k, alpha, beta).
class LRN : public OpDefImplBase<LRN> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    uint32_t n = 5;
    float k = 2.f;
    float alpha = 1e-4f;
    float beta = 0.75f;
    LRN() = default;
    LRN(uint32_t n_, float k_, float alpha_, float beta_, std::string scope_ = {}): n(n_), k(k_), alpha(alpha_), beta(beta_) { set_scope(scope_); }
    LRN(::megdnn::param::LRN packed_param_0): n(packed_param_0.n), k(packed_param_0.k), alpha(packed_param_0.alpha), beta(packed_param_0.beta) {}
    ::megdnn::param::LRN param() const {
        return {n, k, alpha, beta};
    }
};
  1031. class LSQ : public OpDefImplBase<LSQ> {
  1032. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1033. public:
  1034. int32_t qmin = -2147483648;
  1035. int32_t qmax = 2147483647;
  1036. LSQ() = default;
  1037. LSQ(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  1038. LSQ(::megdnn::param::LSQ packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  1039. ::megdnn::param::LSQ param() const {
  1040. return {qmin, qmax};
  1041. }
  1042. };
// LSTM: auto-generated opdef mirroring ::megdnn::param::LSTM
// field-for-field; param() repacks the fields in declaration order.
class LSTM : public OpDefImplBase<LSTM> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using FwdMode = ::megdnn::param::LSTM::FwdMode;
    uint32_t num_layers = 1;
    bool bidirectional = false;
    bool bias = true;
    uint32_t hidden_size = 128;
    // 0 presumably means "no projection" — confirm against megdnn docs.
    uint32_t proj_size = 0;
    float dropout = 0.f;
    FwdMode fwd_mode = ::megdnn::param::LSTM::FwdMode::TRAINING;
    LSTM() = default;
    LSTM(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, uint32_t proj_size_, float dropout_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), proj_size(proj_size_), dropout(dropout_), fwd_mode(fwd_mode_) { set_scope(scope_); }
    LSTM(::megdnn::param::LSTM packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), proj_size(packed_param_0.proj_size), dropout(packed_param_0.dropout), fwd_mode(packed_param_0.fwd_mode) {}
    ::megdnn::param::LSTM param() const {
        return {num_layers, bidirectional, bias, hidden_size, proj_size, dropout, fwd_mode};
    }
};
// LSTMCell: auto-generated opdef with no parameters; the packed param type
// is ::megdnn::param::Empty.
class LSTMCell : public OpDefImplBase<LSTMCell> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    LSTMCell() = default;
    // Accepts (and discards) the empty packed param for interface uniformity.
    LSTMCell(::megdnn::param::Empty) {}
    ::megdnn::param::Empty param() const {
        return {};
    }
};
// LayerNorm: auto-generated opdef mirroring ::megdnn::param::LayerNorm
// (affine toggle, epsilon, and the normalized dim/size extents).
class LayerNorm : public OpDefImplBase<LayerNorm> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    bool affine = true;
    float eps = 1e-5f;
    uint64_t normalized_dim = 1;
    uint64_t normalized_size = 1;
    LayerNorm() = default;
    LayerNorm(bool affine_, float eps_, uint64_t normalized_dim_, uint64_t normalized_size_, std::string scope_ = {}): affine(affine_), eps(eps_), normalized_dim(normalized_dim_), normalized_size(normalized_size_) { set_scope(scope_); }
    LayerNorm(::megdnn::param::LayerNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), normalized_dim(packed_param_0.normalized_dim), normalized_size(packed_param_0.normalized_size) {}
    ::megdnn::param::LayerNorm param() const {
        return {affine, eps, normalized_dim, normalized_size};
    }
};
// Linspace: auto-generated opdef wrapping ::megdnn::param::Linspace plus an
// opdef-local target comp_node that is NOT part of the packed param
// (param() returns only {endpoint}).
class Linspace : public OpDefImplBase<Linspace> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    bool endpoint = true;
    ::mgb::CompNode comp_node;
    Linspace() = default;
    Linspace(bool endpoint_, ::mgb::CompNode comp_node_, std::string scope_ = {}): endpoint(endpoint_), comp_node(comp_node_) { set_scope(scope_); }
    Linspace(::megdnn::param::Linspace packed_param_0, ::mgb::CompNode comp_node_): endpoint(packed_param_0.endpoint), comp_node(comp_node_) {}
    ::megdnn::param::Linspace param() const {
        return {endpoint};
    }
};
// MagicMindRuntime: auto-generated opdef holding a serialized runtime model
// buffer and its size; no packed megdnn param.
class MagicMindRuntime : public OpDefImplBase<MagicMindRuntime> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::string buf;
    // No default: must be supplied alongside buf. Presumably equals
    // buf.size() — confirm against callers.
    size_t buf_size;
    MagicMindRuntime() = default;
    MagicMindRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
};
// MatrixInverse: auto-generated opdef with no parameters; the packed param
// type is ::megdnn::param::Empty.
class MatrixInverse : public OpDefImplBase<MatrixInverse> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    MatrixInverse() = default;
    // Accepts (and discards) the empty packed param for interface uniformity.
    MatrixInverse(::megdnn::param::Empty) {}
    ::megdnn::param::Empty param() const {
        return {};
    }
};
// MatrixMul: auto-generated opdef packing TWO megdnn params —
// ::megdnn::param::MatrixMul (via param()) and
// ::megdnn::param::ExecutionPolicy (via policy()) — plus opdef-local
// dimA/dimB that belong to neither packed struct.
class MatrixMul : public OpDefImplBase<MatrixMul> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
    using Format = ::megdnn::param::MatrixMul::Format;
    using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
    bool transposeA = false;
    bool transposeB = false;
    ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
    Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
    // Default is the enum value with underlying integer 1.
    Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
    // UINT64_MAX == no workspace limit.
    uint64_t workspace_limit = 18446744073709551615ull;
    // Input ranks; no defaults — must be supplied.
    uint32_t dimA;
    uint32_t dimB;
    MatrixMul() = default;
    MatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
    set_scope(scope_);
    // Strategy is a bit-mask-style enum; 8 is the highest accepted value.
    mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
    }
    MatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
    mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
    }
    ::megdnn::param::MatrixMul param() const {
        return {transposeA, transposeB, compute_mode, format};
    }
    ::megdnn::param::ExecutionPolicy policy() const {
        return {strategy, workspace_limit};
    }
};
// MeshGrid: auto-generated opdef; `indexing` is a free-form string (valid
// values defined by the implementation — presumably "xy"/"ij" as in
// numpy.meshgrid; confirm against callers).
class MeshGrid : public OpDefImplBase<MeshGrid> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::string indexing;
    MeshGrid() = default;
    MeshGrid(std::string indexing_, std::string scope_ = {}): indexing(indexing_) { set_scope(scope_); }
};
// MeshIndexing: auto-generated opdef carrying only the per-axis index
// descriptors; no packed megdnn param.
class MeshIndexing : public OpDefImplBase<MeshIndexing> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // One descriptor per indexed axis: (axis, flag, flag, flag, flag).
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
    MeshIndexing() = default;
    MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
// NMSKeep: auto-generated opdef for non-maximum suppression; no packed
// megdnn param.
class NMSKeep : public OpDefImplBase<NMSKeep> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // No defaults: both fields must be supplied.
    float iou_thresh;
    uint32_t max_output;
    NMSKeep() = default;
    NMSKeep(float iou_thresh_, uint32_t max_output_, std::string scope_ = {}): iou_thresh(iou_thresh_), max_output(max_output_) { set_scope(scope_); }
};
// NvOf: auto-generated opdef mirroring ::megdnn::param::NvOf
// (NVIDIA optical flow precision level).
class NvOf : public OpDefImplBase<NvOf> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    uint32_t precision = 1;
    NvOf() = default;
    NvOf(uint32_t precision_, std::string scope_ = {}): precision(precision_) { set_scope(scope_); }
    NvOf(::megdnn::param::NvOf packed_param_0): precision(packed_param_0.precision) {}
    ::megdnn::param::NvOf param() const {
        return {precision};
    }
};
// Padding: auto-generated opdef mirroring ::megdnn::param::Padding.
// Per-dimension front/back offsets for up to 7 dimensions, a fill value for
// CONSTANT mode, and the padding mode.
class Padding : public OpDefImplBase<Padding> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using PaddingMode = ::megdnn::param::Padding::PaddingMode;
    uint32_t front_offset_dim0 = 0;
    uint32_t front_offset_dim1 = 0;
    uint32_t front_offset_dim2 = 0;
    uint32_t front_offset_dim3 = 0;
    uint32_t front_offset_dim4 = 0;
    uint32_t front_offset_dim5 = 0;
    uint32_t front_offset_dim6 = 0;
    uint32_t back_offset_dim0 = 0;
    uint32_t back_offset_dim1 = 0;
    uint32_t back_offset_dim2 = 0;
    uint32_t back_offset_dim3 = 0;
    uint32_t back_offset_dim4 = 0;
    uint32_t back_offset_dim5 = 0;
    uint32_t back_offset_dim6 = 0;
    float padding_val = 0;
    PaddingMode padding_mode = ::megdnn::param::Padding::PaddingMode::CONSTANT;
    Padding() = default;
    Padding(uint32_t front_offset_dim0_, uint32_t front_offset_dim1_, uint32_t front_offset_dim2_, uint32_t front_offset_dim3_, uint32_t front_offset_dim4_, uint32_t front_offset_dim5_, uint32_t front_offset_dim6_, uint32_t back_offset_dim0_, uint32_t back_offset_dim1_, uint32_t back_offset_dim2_, uint32_t back_offset_dim3_, uint32_t back_offset_dim4_, uint32_t back_offset_dim5_, uint32_t back_offset_dim6_, float padding_val_, PaddingMode padding_mode_, std::string scope_ = {}): front_offset_dim0(front_offset_dim0_), front_offset_dim1(front_offset_dim1_), front_offset_dim2(front_offset_dim2_), front_offset_dim3(front_offset_dim3_), front_offset_dim4(front_offset_dim4_), front_offset_dim5(front_offset_dim5_), front_offset_dim6(front_offset_dim6_), back_offset_dim0(back_offset_dim0_), back_offset_dim1(back_offset_dim1_), back_offset_dim2(back_offset_dim2_), back_offset_dim3(back_offset_dim3_), back_offset_dim4(back_offset_dim4_), back_offset_dim5(back_offset_dim5_), back_offset_dim6(back_offset_dim6_), padding_val(padding_val_), padding_mode(padding_mode_) { set_scope(scope_); }
    Padding(::megdnn::param::Padding packed_param_0): front_offset_dim0(packed_param_0.front_offset_dim0), front_offset_dim1(packed_param_0.front_offset_dim1), front_offset_dim2(packed_param_0.front_offset_dim2), front_offset_dim3(packed_param_0.front_offset_dim3), front_offset_dim4(packed_param_0.front_offset_dim4), front_offset_dim5(packed_param_0.front_offset_dim5), front_offset_dim6(packed_param_0.front_offset_dim6), back_offset_dim0(packed_param_0.back_offset_dim0), back_offset_dim1(packed_param_0.back_offset_dim1), back_offset_dim2(packed_param_0.back_offset_dim2), back_offset_dim3(packed_param_0.back_offset_dim3), back_offset_dim4(packed_param_0.back_offset_dim4), back_offset_dim5(packed_param_0.back_offset_dim5), back_offset_dim6(packed_param_0.back_offset_dim6), padding_val(packed_param_0.padding_val), padding_mode(packed_param_0.padding_mode) {}
    ::megdnn::param::Padding param() const {
        return {front_offset_dim0, front_offset_dim1, front_offset_dim2, front_offset_dim3, front_offset_dim4, front_offset_dim5, front_offset_dim6, back_offset_dim0, back_offset_dim1, back_offset_dim2, back_offset_dim3, back_offset_dim4, back_offset_dim5, back_offset_dim6, padding_val, padding_mode};
    }
};
// ParamPackConcat: auto-generated opdef; `offsets` lays out where each packed
// tensor lands in the concatenated buffer. No packed megdnn param.
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::vector<int32_t> offsets;
    ParamPackConcat() = default;
    ParamPackConcat(std::vector<int32_t> offsets_, std::string scope_ = {}): offsets(offsets_) { set_scope(scope_); }
};
// ParamPackSplit: auto-generated opdef; inverse of ParamPackConcat —
// `offsets` locates each output in the packed buffer and `shapes` gives each
// output's tensor shape. No packed megdnn param.
class ParamPackSplit : public OpDefImplBase<ParamPackSplit> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::vector<int32_t> offsets;
    std::vector<std::vector<size_t>> shapes;
    ParamPackSplit() = default;
    ParamPackSplit(std::vector<int32_t> offsets_, std::vector<std::vector<size_t>> shapes_, std::string scope_ = {}): offsets(offsets_), shapes(shapes_) { set_scope(scope_); }
};
// PermutationRNG: auto-generated opdef for random-permutation generation.
// `handle` identifies the RNG handle owned elsewhere; it is opdef-local and
// there is no packed megdnn param exposed here.
class PermutationRNG : public OpDefImplBase<PermutationRNG> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    uint64_t seed = 0;
    // Output dtype; defaults to Int32.
    ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32);
    // No default: must be supplied.
    size_t handle;
    PermutationRNG() = default;
    PermutationRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
};
// PixelShuffle: auto-generated opdef; `factor` is the upscale factor.
// No packed megdnn param.
class PixelShuffle : public OpDefImplBase<PixelShuffle> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // No default: must be supplied.
    int32_t factor;
    PixelShuffle() = default;
    PixelShuffle(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
};
// PixelShuffleBackward: auto-generated opdef; gradient counterpart of
// PixelShuffle with the same `factor` field. No packed megdnn param.
class PixelShuffleBackward : public OpDefImplBase<PixelShuffleBackward> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // No default: must be supplied.
    int32_t factor;
    PixelShuffleBackward() = default;
    PixelShuffleBackward(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
};
// PoissonRNG: auto-generated opdef wrapping ::megdnn::param::PoissonRNG plus
// an opdef-local RNG `handle` that is NOT part of the packed param
// (param() returns only {seed}).
class PoissonRNG : public OpDefImplBase<PoissonRNG> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    uint64_t seed = 0;
    // No default: must be supplied.
    size_t handle;
    PoissonRNG() = default;
    PoissonRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
    PoissonRNG(::megdnn::param::PoissonRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
    ::megdnn::param::PoissonRNG param() const {
        return {seed};
    }
};
// Pooling: auto-generated opdef packing TWO megdnn params —
// ::megdnn::param::Pooling (via param()) and
// ::megdnn::param::ExecutionPolicy (via policy()).
class Pooling : public OpDefImplBase<Pooling> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Pooling::Mode;
    using Format = ::megdnn::param::Pooling::Format;
    using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
    Mode mode = ::megdnn::param::Pooling::Mode::MAX;
    uint32_t pad_h = 0;
    uint32_t pad_w = 0;
    uint32_t stride_h = 2;
    uint32_t stride_w = 2;
    uint32_t window_h = 2;
    uint32_t window_w = 2;
    Format format = ::megdnn::param::Pooling::Format::NCHW;
    // Default is the enum value with underlying integer 1.
    Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
    // UINT64_MAX == no workspace limit.
    uint64_t workspace_limit = 18446744073709551615ull;
    Pooling() = default;
    Pooling(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t window_h_, uint32_t window_w_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), window_h(window_h_), window_w(window_w_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
    set_scope(scope_);
    // Strategy is a bit-mask-style enum; 8 is the highest accepted value.
    mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
    }
    Pooling(::megdnn::param::Pooling packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
    mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
    }
    ::megdnn::param::Pooling param() const {
        return {mode, pad_h, pad_w, stride_h, stride_w, window_h, window_w, format};
    }
    ::megdnn::param::ExecutionPolicy policy() const {
        return {strategy, workspace_limit};
    }
};
// RNN: auto-generated opdef mirroring ::megdnn::param::RNN field-for-field;
// param() repacks the fields in declaration order.
class RNN : public OpDefImplBase<RNN> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using NonlineMode = ::megdnn::param::RNN::NonlineMode;
    using FwdMode = ::megdnn::param::RNN::FwdMode;
    uint32_t num_layers = 1;
    bool bidirectional = false;
    bool bias = true;
    uint32_t hidden_size = 128;
    float dropout = 0.f;
    NonlineMode nonlineMode = ::megdnn::param::RNN::NonlineMode::IDENTITY;
    FwdMode fwd_mode = ::megdnn::param::RNN::FwdMode::TRAINING;
    RNN() = default;
    RNN(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, float dropout_, NonlineMode nonlineMode_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), dropout(dropout_), nonlineMode(nonlineMode_), fwd_mode(fwd_mode_) { set_scope(scope_); }
    RNN(::megdnn::param::RNN packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), dropout(packed_param_0.dropout), nonlineMode(packed_param_0.nonlineMode), fwd_mode(packed_param_0.fwd_mode) {}
    ::megdnn::param::RNN param() const {
        return {num_layers, bidirectional, bias, hidden_size, dropout, nonlineMode, fwd_mode};
    }
};
// RNNCell: auto-generated opdef mirroring ::megdnn::param::RNNCell
// (single nonlinearity selector).
class RNNCell : public OpDefImplBase<RNNCell> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using NonlineMode = ::megdnn::param::RNNCell::NonlineMode;
    NonlineMode nonlineMode = ::megdnn::param::RNNCell::NonlineMode::IDENTITY;
    RNNCell() = default;
    RNNCell(NonlineMode nonlineMode_, std::string scope_ = {}): nonlineMode(nonlineMode_) { set_scope(scope_); }
    RNNCell(::megdnn::param::RNNCell packed_param_0): nonlineMode(packed_param_0.nonlineMode) {}
    ::megdnn::param::RNNCell param() const {
        return {nonlineMode};
    }
};
// ROIAlign: auto-generated opdef mirroring ::megdnn::param::ROIAlign
// field-for-field; param() repacks the fields in declaration order.
class ROIAlign : public OpDefImplBase<ROIAlign> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::ROIAlign::Mode;
    using Format = ::megdnn::param::ROIAlign::Format;
    Mode mode = ::megdnn::param::ROIAlign::Mode::MAX;
    Format format = ::megdnn::param::ROIAlign::Format::NCHW;
    float spatial_scale = 1.0;
    float offset = 0.0;
    uint32_t pooled_height = 1;
    uint32_t pooled_width = 1;
    uint32_t sample_height = 2;
    uint32_t sample_width = 2;
    ROIAlign() = default;
    ROIAlign(Mode mode_, Format format_, float spatial_scale_, float offset_, uint32_t pooled_height_, uint32_t pooled_width_, uint32_t sample_height_, uint32_t sample_width_, std::string scope_ = {}): mode(mode_), format(format_), spatial_scale(spatial_scale_), offset(offset_), pooled_height(pooled_height_), pooled_width(pooled_width_), sample_height(sample_height_), sample_width(sample_width_) { set_scope(scope_); }
    ROIAlign(::megdnn::param::ROIAlign packed_param_0): mode(packed_param_0.mode), format(packed_param_0.format), spatial_scale(packed_param_0.spatial_scale), offset(packed_param_0.offset), pooled_height(packed_param_0.pooled_height), pooled_width(packed_param_0.pooled_width), sample_height(packed_param_0.sample_height), sample_width(packed_param_0.sample_width) {}
    ::megdnn::param::ROIAlign param() const {
        return {mode, format, spatial_scale, offset, pooled_height, pooled_width, sample_height, sample_width};
    }
};
// ROIPooling: auto-generated opdef mirroring ::megdnn::param::ROIPooling
// (pooling mode and coordinate scale).
class ROIPooling : public OpDefImplBase<ROIPooling> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::ROIPooling::Mode;
    Mode mode = ::megdnn::param::ROIPooling::Mode::MAX;
    float scale = 1.f;
    ROIPooling() = default;
    ROIPooling(Mode mode_, float scale_, std::string scope_ = {}): mode(mode_), scale(scale_) { set_scope(scope_); }
    ROIPooling(::megdnn::param::ROIPooling packed_param_0): mode(packed_param_0.mode), scale(packed_param_0.scale) {}
    ::megdnn::param::ROIPooling param() const {
        return {mode, scale};
    }
};
// Reduce: auto-generated opdef wrapping ::megdnn::param::Reduce plus an
// opdef-local `keepdim` that is NOT part of the packed param (param()
// returns only {mode, axis, data_type}).
class Reduce : public OpDefImplBase<Reduce> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Reduce::Mode;
    using DataType = ::megdnn::param::Reduce::DataType;
    Mode mode = ::megdnn::param::Reduce::Mode::SUM;
    // INT32_MAX sentinel — presumably means "reduce over all axes"; confirm
    // against the megdnn Reduce param definition.
    int32_t axis = 2147483647;
    DataType data_type = ::megdnn::param::Reduce::DataType::DEFAULT;
    bool keepdim = true;
    Reduce() = default;
    Reduce(Mode mode_, int32_t axis_, DataType data_type_, bool keepdim_, std::string scope_ = {}): mode(mode_), axis(axis_), data_type(data_type_), keepdim(keepdim_) { set_scope(scope_); }
    Reduce(::megdnn::param::Reduce packed_param_0, bool keepdim_): mode(packed_param_0.mode), axis(packed_param_0.axis), data_type(packed_param_0.data_type), keepdim(keepdim_) {}
    ::megdnn::param::Reduce param() const {
        return {mode, axis, data_type};
    }
};
// RegionRestrictedConvolution: auto-generated opdef reusing
// ::megdnn::param::Convolution as its packed param (no dedicated param
// struct of its own).
class RegionRestrictedConvolution : public OpDefImplBase<RegionRestrictedConvolution> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Convolution::Mode;
    using Sparse = ::megdnn::param::Convolution::Sparse;
    using Format = ::megdnn::param::Convolution::Format;
    using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
    Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
    uint32_t pad_h = 0;
    uint32_t pad_w = 0;
    uint32_t stride_h = 1;
    uint32_t stride_w = 1;
    uint32_t dilate_h = 1;
    uint32_t dilate_w = 1;
    Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
    Format format = ::megdnn::param::Convolution::Format::NCHW;
    ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
    RegionRestrictedConvolution() = default;
    RegionRestrictedConvolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
    RegionRestrictedConvolution(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
    ::megdnn::param::Convolution param() const {
        return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
    }
};
// RegionRestrictedConvolutionBackwardData: auto-generated opdef; data-grad
// counterpart of RegionRestrictedConvolution, likewise packing
// ::megdnn::param::Convolution.
class RegionRestrictedConvolutionBackwardData : public OpDefImplBase<RegionRestrictedConvolutionBackwardData> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = ::megdnn::param::Convolution::Mode;
    using Sparse = ::megdnn::param::Convolution::Sparse;
    using Format = ::megdnn::param::Convolution::Format;
    using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
    Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
    uint32_t pad_h = 0;
    uint32_t pad_w = 0;
    uint32_t stride_h = 1;
    uint32_t stride_w = 1;
    uint32_t dilate_h = 1;
    uint32_t dilate_w = 1;
    Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
    Format format = ::megdnn::param::Convolution::Format::NCHW;
    ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
    RegionRestrictedConvolutionBackwardData() = default;
    RegionRestrictedConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
    RegionRestrictedConvolutionBackwardData(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
    ::megdnn::param::Convolution param() const {
        return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
    }
};
// Remap: auto-generated opdef mirroring ::megdnn::param::Remap
// (interpolation, border handling, layout, and the CONSTANT-border scalar).
class Remap : public OpDefImplBase<Remap> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using InterpolationMode = ::megdnn::param::Remap::InterpolationMode;
    using BorderMode = ::megdnn::param::Remap::BorderMode;
    using Format = ::megdnn::param::Remap::Format;
    InterpolationMode imode = ::megdnn::param::Remap::InterpolationMode::LINEAR;
    BorderMode border_type = ::megdnn::param::Remap::BorderMode::REPLICATE;
    // Note: defaults to NHWC, unlike the conv-style ops above (NCHW).
    Format format = ::megdnn::param::Remap::Format::NHWC;
    float scalar = 0.f;
    Remap() = default;
    Remap(InterpolationMode imode_, BorderMode border_type_, Format format_, float scalar_, std::string scope_ = {}): imode(imode_), border_type(border_type_), format(format_), scalar(scalar_) { set_scope(scope_); }
    Remap(::megdnn::param::Remap packed_param_0): imode(packed_param_0.imode), border_type(packed_param_0.border_type), format(packed_param_0.format), scalar(packed_param_0.scalar) {}
    ::megdnn::param::Remap param() const {
        return {imode, border_type, format, scalar};
    }
};
// RemoteRecv: auto-generated opdef describing a distributed receive endpoint;
// all fields are opdef-local (no packed megdnn param). The receiver must know
// the incoming tensor's shape/dtype and target comp node up front.
class RemoteRecv : public OpDefImplBase<RemoteRecv> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // Rendezvous key shared with the matching RemoteSend.
    std::string key;
    std::string addr;
    uint32_t port;
    uint32_t rank_from;
    ::mgb::CompNode cn;
    std::vector<int32_t> shape;
    ::megdnn::DType dtype;
    // Transport backend selector (free-form string).
    std::string backend;
    RemoteRecv() = default;
    RemoteRecv(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_from_, ::mgb::CompNode cn_, std::vector<int32_t> shape_, ::megdnn::DType dtype_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_from(rank_from_), cn(cn_), shape(shape_), dtype(dtype_), backend(backend_) { set_scope(scope_); }
};
// RemoteSend: auto-generated opdef describing a distributed send endpoint;
// counterpart of RemoteRecv (matched via `key`). No packed megdnn param.
class RemoteSend : public OpDefImplBase<RemoteSend> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::string key;
    std::string addr;
    uint32_t port;
    uint32_t rank_to;
    // Transport backend selector (free-form string).
    std::string backend;
    RemoteSend() = default;
    RemoteSend(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_to_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_to(rank_to_), backend(backend_) { set_scope(scope_); }
};
// RemoveAxis: auto-generated opdef; `axis` lists the axes to squeeze.
// No packed megdnn param.
class RemoveAxis : public OpDefImplBase<RemoveAxis> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    std::vector<int32_t> axis;
    RemoveAxis() = default;
    RemoveAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
};
// Reshape: auto-generated opdef wrapping ::megdnn::param::OptionalAxisV1 plus
// an opdef-local target `shape` that is NOT part of the packed param
// (param() returns only {axis}).
class Reshape : public OpDefImplBase<Reshape> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // INVALID_AXIS means no axis is singled out for inference.
    int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
    std::vector<int32_t> shape;
    Reshape() = default;
    Reshape(int32_t axis_, std::vector<int32_t> shape_, std::string scope_ = {}): axis(axis_), shape(shape_) { set_scope(scope_); }
    Reshape(::megdnn::param::OptionalAxisV1 packed_param_0, std::vector<int32_t> shape_): axis(packed_param_0.axis), shape(shape_) {}
    ::megdnn::param::OptionalAxisV1 param() const {
        return {axis};
    }
};
  1472. class Resize : public OpDefImplBase<Resize> {
  1473. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1474. public:
  1475. using InterpolationMode = ::megdnn::param::Resize::InterpolationMode;
  1476. using Format = ::megdnn::param::Resize::Format;
  1477. InterpolationMode imode = ::megdnn::param::Resize::InterpolationMode::LINEAR;
  1478. Format format = ::megdnn::param::Resize::Format::NHWC;
  1479. Resize() = default;
  1480. Resize(InterpolationMode imode_, Format format_, std::string scope_ = {}): imode(imode_), format(format_) { set_scope(scope_); }
  1481. Resize(::megdnn::param::Resize packed_param_0): imode(packed_param_0.imode), format(packed_param_0.format) {}
  1482. ::megdnn::param::Resize param() const {
  1483. return {imode, format};
  1484. }
  1485. };
  1486. class SVD : public OpDefImplBase<SVD> {
  1487. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1488. public:
  1489. bool full_matrices = false;
  1490. bool compute_uv = true;
  1491. SVD() = default;
  1492. SVD(bool full_matrices_, bool compute_uv_, std::string scope_ = {}): full_matrices(full_matrices_), compute_uv(compute_uv_) { set_scope(scope_); }
  1493. SVD(::megdnn::param::SVD packed_param_0): full_matrices(packed_param_0.full_matrices), compute_uv(packed_param_0.compute_uv) {}
  1494. ::megdnn::param::SVD param() const {
  1495. return {full_matrices, compute_uv};
  1496. }
  1497. };
  1498. class SetMeshIndexing : public OpDefImplBase<SetMeshIndexing> {
  1499. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1500. public:
  1501. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1502. SetMeshIndexing() = default;
  1503. SetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1504. };
  1505. class SetSubtensor : public OpDefImplBase<SetSubtensor> {
  1506. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1507. public:
  1508. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1509. SetSubtensor() = default;
  1510. SetSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1511. };
  1512. class ShuffleRNG : public OpDefImplBase<ShuffleRNG> {
  1513. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1514. public:
  1515. uint64_t seed = 0;
  1516. size_t handle;
  1517. ShuffleRNG() = default;
  1518. ShuffleRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  1519. ShuffleRNG(::megdnn::param::ShuffleRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  1520. ::megdnn::param::ShuffleRNG param() const {
  1521. return {seed};
  1522. }
  1523. };
  1524. class SlidingWindowTranspose : public OpDefImplBase<SlidingWindowTranspose> {
  1525. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1526. public:
  1527. uint32_t out_h = 0;
  1528. uint32_t out_w = 0;
  1529. uint32_t pad_h = 0;
  1530. uint32_t pad_w = 0;
  1531. uint32_t stride_h = 1;
  1532. uint32_t stride_w = 1;
  1533. uint32_t dilate_h = 1;
  1534. uint32_t dilate_w = 1;
  1535. uint32_t window_h = 3;
  1536. uint32_t window_w = 3;
  1537. SlidingWindowTranspose() = default;
  1538. SlidingWindowTranspose(uint32_t out_h_, uint32_t out_w_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): out_h(out_h_), out_w(out_w_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
  1539. SlidingWindowTranspose(::megdnn::param::SlidingWindowTranspose packed_param_0): out_h(packed_param_0.out_h), out_w(packed_param_0.out_w), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
  1540. ::megdnn::param::SlidingWindowTranspose param() const {
  1541. return {out_h, out_w, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
  1542. }
  1543. };
  1544. class Softmax : public OpDefImplBase<Softmax> {
  1545. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1546. public:
  1547. int32_t axis = -1;
  1548. Softmax() = default;
  1549. Softmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  1550. Softmax(::megdnn::param::Softmax packed_param_0): axis(packed_param_0.axis) {}
  1551. ::megdnn::param::Softmax param() const {
  1552. return {axis};
  1553. }
  1554. };
  1555. class Split : public OpDefImplBase<Split> {
  1556. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1557. public:
  1558. int32_t axis;
  1559. int32_t nsections;
  1560. Split() = default;
  1561. Split(int32_t axis_, int32_t nsections_, std::string scope_ = {}): axis(axis_), nsections(nsections_) { set_scope(scope_); }
  1562. Split(::megdnn::param::Empty, int32_t axis_, int32_t nsections_): axis(axis_), nsections(nsections_) {}
  1563. ::megdnn::param::Empty param() const {
  1564. return {};
  1565. }
  1566. };
  1567. class Subtensor : public OpDefImplBase<Subtensor> {
  1568. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1569. public:
  1570. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1571. Subtensor() = default;
  1572. Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1573. };
  1574. class TQT : public OpDefImplBase<TQT> {
  1575. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1576. public:
  1577. int32_t qmin = -2147483648;
  1578. int32_t qmax = 2147483647;
  1579. TQT() = default;
  1580. TQT(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  1581. TQT(::megdnn::param::TQT packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  1582. ::megdnn::param::TQT param() const {
  1583. return {qmin, qmax};
  1584. }
  1585. };
  1586. class TensorRTRuntime : public OpDefImplBase<TensorRTRuntime> {
  1587. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1588. public:
  1589. std::string buf;
  1590. size_t buf_size;
  1591. TensorRTRuntime() = default;
  1592. TensorRTRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
  1593. };
  1594. class TopK : public OpDefImplBase<TopK> {
  1595. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1596. public:
  1597. using Mode = ::megdnn::param::TopK::Mode;
  1598. Mode mode = ::megdnn::param::TopK::Mode::KTH_ONLY;
  1599. TopK() = default;
  1600. TopK(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  1601. TopK(::megdnn::param::TopK packed_param_0): mode(packed_param_0.mode) {}
  1602. ::megdnn::param::TopK param() const {
  1603. return {mode};
  1604. }
  1605. };
  1606. class TypeCvt : public OpDefImplBase<TypeCvt> {
  1607. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1608. public:
  1609. ::megdnn::DType dtype;
  1610. TypeCvt() = default;
  1611. TypeCvt(::megdnn::DType dtype_, std::string scope_ = {}): dtype(dtype_) { set_scope(scope_); }
  1612. };
  1613. class UniformRNG : public OpDefImplBase<UniformRNG> {
  1614. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1615. public:
  1616. uint64_t seed = 0;
  1617. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  1618. size_t handle;
  1619. UniformRNG() = default;
  1620. UniformRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
  1621. };
  1622. class WarpAffine : public OpDefImplBase<WarpAffine> {
  1623. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1624. public:
  1625. using InterpolationMode = ::megdnn::param::WarpAffine::InterpolationMode;
  1626. using BorderMode = ::megdnn::param::WarpAffine::BorderMode;
  1627. using Format = ::megdnn::param::WarpAffine::Format;
  1628. InterpolationMode imode = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR;
  1629. BorderMode border_mode = ::megdnn::param::WarpAffine::BorderMode::REPLICATE;
  1630. float border_val = .0f;
  1631. Format format = ::megdnn::param::WarpAffine::Format::NHWC;
  1632. WarpAffine() = default;
  1633. WarpAffine(InterpolationMode imode_, BorderMode border_mode_, float border_val_, Format format_, std::string scope_ = {}): imode(imode_), border_mode(border_mode_), border_val(border_val_), format(format_) { set_scope(scope_); }
  1634. WarpAffine(::megdnn::param::WarpAffine packed_param_0): imode(packed_param_0.imode), border_mode(packed_param_0.border_mode), border_val(packed_param_0.border_val), format(packed_param_0.format) {}
  1635. ::megdnn::param::WarpAffine param() const {
  1636. return {imode, border_mode, border_val, format};
  1637. }
  1638. };
  1639. class WarpPerspective : public OpDefImplBase<WarpPerspective> {
  1640. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1641. public:
  1642. using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
  1643. using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
  1644. using Format = ::megdnn::param::WarpPerspective::Format;
  1645. InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
  1646. BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
  1647. Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
  1648. float border_val = .0f;
  1649. WarpPerspective() = default;
  1650. WarpPerspective(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
  1651. WarpPerspective(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
  1652. ::megdnn::param::WarpPerspective param() const {
  1653. return {imode, bmode, format, border_val};
  1654. }
  1655. };
  1656. class WarpPerspectiveBackwardData : public OpDefImplBase<WarpPerspectiveBackwardData> {
  1657. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1658. public:
  1659. using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
  1660. using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
  1661. using Format = ::megdnn::param::WarpPerspective::Format;
  1662. InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
  1663. BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
  1664. Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
  1665. float border_val = .0f;
  1666. WarpPerspectiveBackwardData() = default;
  1667. WarpPerspectiveBackwardData(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
  1668. WarpPerspectiveBackwardData(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
  1669. ::megdnn::param::WarpPerspective param() const {
  1670. return {imode, bmode, format, border_val};
  1671. }
  1672. };
  1673. class WarpPerspectiveBackwardMat : public OpDefImplBase<WarpPerspectiveBackwardMat> {
  1674. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1675. public:
  1676. using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
  1677. using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
  1678. using Format = ::megdnn::param::WarpPerspective::Format;
  1679. InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
  1680. BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
  1681. Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
  1682. float border_val = .0f;
  1683. WarpPerspectiveBackwardMat() = default;
  1684. WarpPerspectiveBackwardMat(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
  1685. WarpPerspectiveBackwardMat(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
  1686. ::megdnn::param::WarpPerspective param() const {
  1687. return {imode, bmode, format, border_val};
  1688. }
  1689. };
  1690. // clang-format on