You cannot select more than 25 topics. Each topic must start with a Chinese character, a letter or a number, may include dashes ('-'), and can be up to 35 characters long.

opdef.h.inl 90 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836
  1. // clang-format off
  2. class AdaptivePooling : public OpDefImplBase<AdaptivePooling> {
  3. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  4. public:
  5. using Mode = ::megdnn::param::AdaptivePooling::Mode;
  6. using Format = ::megdnn::param::AdaptivePooling::Format;
  7. Mode mode = ::megdnn::param::AdaptivePooling::Mode::MAX;
  8. Format format = ::megdnn::param::AdaptivePooling::Format::NCHW;
  9. std::vector<int32_t> shape;
  10. AdaptivePooling() = default;
  11. AdaptivePooling(Mode mode_, Format format_, std::vector<int32_t> shape_, std::string scope_ = {}): mode(mode_), format(format_), shape(shape_) { set_scope(scope_); }
  12. AdaptivePooling(::megdnn::param::AdaptivePooling packed_param_0, std::vector<int32_t> shape_): mode(packed_param_0.mode), format(packed_param_0.format), shape(shape_) {}
  13. ::megdnn::param::AdaptivePooling param() const {
  14. return {mode, format};
  15. }
  16. };
  17. class AddAxis : public OpDefImplBase<AddAxis> {
  18. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  19. public:
  20. std::vector<int32_t> axis;
  21. AddAxis() = default;
  22. AddAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  23. };
  24. class Argmax : public OpDefImplBase<Argmax> {
  25. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  26. public:
  27. int32_t axis = 0;
  28. Argmax() = default;
  29. Argmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  30. Argmax(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
  31. ::megdnn::param::Axis param() const {
  32. return {axis};
  33. }
  34. };
  35. class Argmin : public OpDefImplBase<Argmin> {
  36. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  37. public:
  38. int32_t axis = 0;
  39. Argmin() = default;
  40. Argmin(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  41. Argmin(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
  42. ::megdnn::param::Axis param() const {
  43. return {axis};
  44. }
  45. };
  46. class Argsort : public OpDefImplBase<Argsort> {
  47. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  48. public:
  49. using Order = ::megdnn::param::Argsort::Order;
  50. Order order = ::megdnn::param::Argsort::Order::ASCENDING;
  51. Argsort() = default;
  52. Argsort(Order order_, std::string scope_ = {}): order(order_) { set_scope(scope_); }
  53. Argsort(::megdnn::param::Argsort packed_param_0): order(packed_param_0.order) {}
  54. ::megdnn::param::Argsort param() const {
  55. return {order};
  56. }
  57. };
  58. class AssertEqual : public OpDefImplBase<AssertEqual> {
  59. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  60. public:
  61. float maxerr = 0.0001;
  62. bool verbose = false;
  63. AssertEqual() = default;
  64. AssertEqual(float maxerr_, bool verbose_, std::string scope_ = {}): maxerr(maxerr_), verbose(verbose_) { set_scope(scope_); }
  65. AssertEqual(::megdnn::param::AssertEqual packed_param_0): maxerr(packed_param_0.maxerr), verbose(packed_param_0.verbose) {}
  66. ::megdnn::param::AssertEqual param() const {
  67. return {maxerr, verbose};
  68. }
  69. };
  70. class AtlasRuntime : public OpDefImplBase<AtlasRuntime> {
  71. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  72. public:
  73. std::string buf;
  74. size_t buf_size;
  75. AtlasRuntime() = default;
  76. AtlasRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
  77. };
  78. class Barrier : public OpDefImplBase<Barrier> {
  79. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  80. public:
  81. ::mgb::CompNode comp_node;
  82. uint32_t nr_outputs;
  83. Barrier() = default;
  84. Barrier(::mgb::CompNode comp_node_, uint32_t nr_outputs_, std::string scope_ = {}): comp_node(comp_node_), nr_outputs(nr_outputs_) { set_scope(scope_); }
  85. };
// Op definition for BatchConvBias: flattens ::megdnn::param::BatchConvBias,
// an ::megdnn::param::ExecutionPolicy (strategy + workspace limit), and an
// output dtype into one object. param()/policy() re-pack the pieces.
class BatchConvBias : public OpDefImplBase<BatchConvBias> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::BatchConvBias::NonlineMode;
using Mode = ::megdnn::param::BatchConvBias::Mode;
using Sparse = ::megdnn::param::BatchConvBias::Sparse;
using Format = ::megdnn::param::BatchConvBias::Format;
using ComputeMode = ::megdnn::param::BatchConvBias::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
NonlineMode nonlineMode = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY;
Mode mode = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::BatchConvBias::Sparse::DENSE;
Format format = ::megdnn::param::BatchConvBias::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT;
// NOTE(review): default strategy is the raw enum value 1 — presumably the
// generator's encoding of the default heuristic; confirm against
// ExecutionPolicy::Strategy definitions.
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;  // UINT64_MAX: unlimited workspace
::megdnn::DType dtype;
BatchConvBias() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
BatchConvBias(NonlineMode nonlineMode_, Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
set_scope(scope_);
// Reject strategy encodings above 8 (same bound as the packed-param ctor).
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Unpack from the packed megdnn param + execution policy structs.
BatchConvBias(::megdnn::param::BatchConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Re-pack the convolution fields into the megdnn parameter struct.
::megdnn::param::BatchConvBias param() const {
return {nonlineMode, mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
// Re-pack strategy/workspace_limit into an ExecutionPolicy.
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
// Op definition wrapping ::megdnn::param::BN (forward batch normalization).
class BatchNorm : public OpDefImplBase<BatchNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ParamDim = ::megdnn::param::BN::ParamDim;
using FwdMode = ::megdnn::param::BN::FwdMode;
ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
// NOTE(review): double fields initialized from float literals ('f' suffix);
// the stored epsilon is float(1e-4) widened to double, not exactly 1e-4.
double epsilon = 1e-4f;
double avg_factor = 1.f;
float scale = 1.f;
float bias = 0.f;
BatchNorm() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
BatchNorm(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
// Unpack from a packed ::megdnn::param::BN.
BatchNorm(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
// Re-pack the fields into the megdnn parameter struct.
::megdnn::param::BN param() const {
return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
}
};
// Op definition wrapping ::megdnn::param::BN for the backward pass;
// field-for-field identical to BatchNorm above (generator emits both).
class BatchNormBackward : public OpDefImplBase<BatchNormBackward> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ParamDim = ::megdnn::param::BN::ParamDim;
using FwdMode = ::megdnn::param::BN::FwdMode;
ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
// NOTE(review): double fields initialized from float literals ('f' suffix),
// so epsilon stores float(1e-4) widened to double.
double epsilon = 1e-4f;
double avg_factor = 1.f;
float scale = 1.f;
float bias = 0.f;
BatchNormBackward() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
BatchNormBackward(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
// Unpack from a packed ::megdnn::param::BN.
BatchNormBackward(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
// Re-pack the fields into the megdnn parameter struct.
::megdnn::param::BN param() const {
return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
}
};
  160. class BatchedIncrMeshIndexing : public OpDefImplBase<BatchedIncrMeshIndexing> {
  161. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  162. public:
  163. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  164. BatchedIncrMeshIndexing() = default;
  165. BatchedIncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  166. };
// Op definition for BatchedMatrixMul: flattens ::megdnn::param::MatrixMul,
// an ::megdnn::param::ExecutionPolicy, and the two operand ranks (dimA/dimB).
class BatchedMatrixMul : public OpDefImplBase<BatchedMatrixMul> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
using Format = ::megdnn::param::MatrixMul::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
bool transposeA = false;
bool transposeB = false;
ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
// NOTE(review): default strategy is the raw enum value 1 — confirm against
// ExecutionPolicy::Strategy definitions.
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;  // UINT64_MAX: unlimited workspace
uint32_t dimA;
uint32_t dimB;
BatchedMatrixMul() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
BatchedMatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
set_scope(scope_);
// Reject strategy encodings above 8 (same bound as the packed-param ctor).
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Unpack from the packed megdnn param + execution policy structs.
BatchedMatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Re-pack the matmul fields into the megdnn parameter struct.
::megdnn::param::MatrixMul param() const {
return {transposeA, transposeB, compute_mode, format};
}
// Re-pack strategy/workspace_limit into an ExecutionPolicy.
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
  196. class BatchedMeshIndexing : public OpDefImplBase<BatchedMeshIndexing> {
  197. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  198. public:
  199. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  200. BatchedMeshIndexing() = default;
  201. BatchedMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  202. };
  203. class BatchedSetMeshIndexing : public OpDefImplBase<BatchedSetMeshIndexing> {
  204. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  205. public:
  206. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  207. BatchedSetMeshIndexing() = default;
  208. BatchedSetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  209. };
  210. class BetaRNG : public OpDefImplBase<BetaRNG> {
  211. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  212. public:
  213. uint64_t seed = 0;
  214. size_t handle;
  215. BetaRNG() = default;
  216. BetaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  217. BetaRNG(::megdnn::param::BetaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  218. ::megdnn::param::BetaRNG param() const {
  219. return {seed};
  220. }
  221. };
  222. class Borrow : public OpDefImplBase<Borrow> {
  223. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  224. public:
  225. ::mgb::CompNode comp_node;
  226. Borrow() = default;
  227. Borrow(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
  228. };
  229. class Broadcast : public OpDefImplBase<Broadcast> {
  230. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  231. public:
  232. std::vector<int32_t> shape;
  233. Broadcast() = default;
  234. Broadcast(std::vector<int32_t> shape_, std::string scope_ = {}): shape(shape_) { set_scope(scope_); }
  235. Broadcast(::megdnn::param::Empty, std::vector<int32_t> shape_): shape(shape_) {}
  236. ::megdnn::param::Empty param() const {
  237. return {};
  238. }
  239. };
  240. class CambriconRuntime : public OpDefImplBase<CambriconRuntime> {
  241. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  242. public:
  243. std::string buf;
  244. size_t buf_size;
  245. std::string symbol;
  246. bool tensor_dim_mutable;
  247. CambriconRuntime() = default;
  248. CambriconRuntime(std::string buf_, size_t buf_size_, std::string symbol_, bool tensor_dim_mutable_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_), symbol(symbol_), tensor_dim_mutable(tensor_dim_mutable_) { set_scope(scope_); }
  249. };
  250. class CheckNonFinite : public OpDefImplBase<CheckNonFinite> {
  251. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  252. public:
  253. float scale = 1.0;
  254. CheckNonFinite() = default;
  255. CheckNonFinite(float scale_, std::string scope_ = {}): scale(scale_) { set_scope(scope_); }
  256. CheckNonFinite(::megdnn::param::CheckNonFinite packed_param_0): scale(packed_param_0.scale) {}
  257. ::megdnn::param::CheckNonFinite param() const {
  258. return {scale};
  259. }
  260. };
// Op definition for CollectiveComm: the megdnn param contributes only `mode`;
// the remaining fields (group key, topology, endpoint, dtype, backend,
// comp node) are carried alongside and are not part of param().
class CollectiveComm : public OpDefImplBase<CollectiveComm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::CollectiveComm::Mode;
Mode mode = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM;
std::string key;
uint32_t nr_devices;
uint32_t rank;
bool is_root;
bool local_grad;
std::string addr;
uint32_t port;
::megdnn::DType dtype;
std::string backend;
// NOTE(review): comp_node here is a string, unlike other ops that store a
// ::mgb::CompNode — presumably resolved later; confirm at the use site.
std::string comp_node;
CollectiveComm() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
CollectiveComm(Mode mode_, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_, std::string scope_ = {}): mode(mode_), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) { set_scope(scope_); }
// Unpack mode from the packed megdnn param; all other fields passed directly.
CollectiveComm(::megdnn::param::CollectiveComm packed_param_0, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_): mode(packed_param_0.mode), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) {}
// Re-pack the mode into the megdnn parameter struct.
::megdnn::param::CollectiveComm param() const {
return {mode};
}
};
  283. class Concat : public OpDefImplBase<Concat> {
  284. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  285. public:
  286. int32_t axis = 0;
  287. ::mgb::CompNode comp_node;
  288. Concat() = default;
  289. Concat(int32_t axis_, ::mgb::CompNode comp_node_, std::string scope_ = {}): axis(axis_), comp_node(comp_node_) { set_scope(scope_); }
  290. Concat(::megdnn::param::Axis packed_param_0, ::mgb::CompNode comp_node_): axis(packed_param_0.axis), comp_node(comp_node_) {}
  291. ::megdnn::param::Axis param() const {
  292. return {axis};
  293. }
  294. };
  295. class CondTake : public OpDefImplBase<CondTake> {
  296. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  297. public:
  298. CondTake() = default;
  299. };
// Op definition for ConvBias: flattens ::megdnn::param::ConvBias, an
// ::megdnn::param::ExecutionPolicy, and an output dtype. Note the field
// order differs from BatchConvBias (sparse/format precede the pad/stride
// fields here), matching the megdnn param layout.
class ConvBias : public OpDefImplBase<ConvBias> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::ConvBias::NonlineMode;
using Mode = ::megdnn::param::ConvBias::Mode;
using Sparse = ::megdnn::param::ConvBias::Sparse;
using Format = ::megdnn::param::ConvBias::Format;
using ComputeMode = ::megdnn::param::ConvBias::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
NonlineMode nonlineMode = ::megdnn::param::ConvBias::NonlineMode::IDENTITY;
Mode mode = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION;
Sparse sparse = ::megdnn::param::ConvBias::Sparse::DENSE;
Format format = ::megdnn::param::ConvBias::Format::NCHW;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
ComputeMode compute_mode = ::megdnn::param::ConvBias::ComputeMode::DEFAULT;
// NOTE(review): default strategy is the raw enum value 1 — confirm against
// ExecutionPolicy::Strategy definitions.
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;  // UINT64_MAX: unlimited workspace
::megdnn::DType dtype;
ConvBias() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
ConvBias(NonlineMode nonlineMode_, Mode mode_, Sparse sparse_, Format format_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), sparse(sparse_), format(format_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
set_scope(scope_);
// Reject strategy encodings above 8 (same bound as the packed-param ctor).
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Unpack from the packed megdnn param + execution policy structs.
ConvBias(::megdnn::param::ConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), sparse(packed_param_0.sparse), format(packed_param_0.format), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Re-pack the convolution fields into the megdnn parameter struct.
::megdnn::param::ConvBias param() const {
return {nonlineMode, mode, sparse, format, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, compute_mode};
}
// Re-pack strategy/workspace_limit into an ExecutionPolicy.
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
// Op definition for Convolution: flattens ::megdnn::param::Convolution and
// an ::megdnn::param::ExecutionPolicy (strategy + workspace limit).
class Convolution : public OpDefImplBase<Convolution> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
// NOTE(review): default strategy is the raw enum value 1 — confirm against
// ExecutionPolicy::Strategy definitions.
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;  // UINT64_MAX: unlimited workspace
Convolution() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
Convolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
// Reject strategy encodings above 8 (same bound as the packed-param ctor).
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Unpack from the packed megdnn param + execution policy structs.
Convolution(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Re-pack the convolution fields into the megdnn parameter struct.
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
// Re-pack strategy/workspace_limit into an ExecutionPolicy.
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
// Op definition for Convolution3D: flattens ::megdnn::param::Convolution3D
// (adds a depth dimension: pad_d/stride_d/dilate_d and a DataType field)
// plus an ::megdnn::param::ExecutionPolicy.
class Convolution3D : public OpDefImplBase<Convolution3D> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution3D::Mode;
using Sparse = ::megdnn::param::Convolution3D::Sparse;
using DataType = ::megdnn::param::Convolution3D::DataType;
using Format = ::megdnn::param::Convolution3D::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
uint32_t pad_d = 0;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_d = 1;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_d = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
// NOTE(review): default strategy is the raw enum value 1 — confirm against
// ExecutionPolicy::Strategy definitions.
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;  // UINT64_MAX: unlimited workspace
Convolution3D() = default;
// Field-wise constructor; scope_ only forwards to set_scope().
Convolution3D(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
// Reject strategy encodings above 8 (same bound as the packed-param ctor).
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Unpack from the packed megdnn param + execution policy structs.
Convolution3D(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
// Re-pack the convolution fields into the megdnn parameter struct.
::megdnn::param::Convolution3D param() const {
return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
}
// Re-pack strategy/workspace_limit into an ExecutionPolicy.
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
  411. class Convolution3DBackwardData : public OpDefImplBase<Convolution3DBackwardData> {
  412. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  413. public:
  414. using Mode = ::megdnn::param::Convolution3D::Mode;
  415. using Sparse = ::megdnn::param::Convolution3D::Sparse;
  416. using DataType = ::megdnn::param::Convolution3D::DataType;
  417. using Format = ::megdnn::param::Convolution3D::Format;
  418. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  419. Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
  420. uint32_t pad_d = 0;
  421. uint32_t pad_h = 0;
  422. uint32_t pad_w = 0;
  423. uint32_t stride_d = 1;
  424. uint32_t stride_h = 1;
  425. uint32_t stride_w = 1;
  426. uint32_t dilate_d = 1;
  427. uint32_t dilate_h = 1;
  428. uint32_t dilate_w = 1;
  429. Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
  430. DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
  431. Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
  432. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  433. uint64_t workspace_limit = 18446744073709551615ull;
  434. Convolution3DBackwardData() = default;
  435. Convolution3DBackwardData(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
  436. set_scope(scope_);
  437. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  438. }
  439. Convolution3DBackwardData(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  440. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  441. }
  442. ::megdnn::param::Convolution3D param() const {
  443. return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
  444. }
  445. ::megdnn::param::ExecutionPolicy policy() const {
  446. return {strategy, workspace_limit};
  447. }
  448. };
  449. class ConvolutionBackwardData : public OpDefImplBase<ConvolutionBackwardData> {
  450. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  451. public:
  452. using Mode = ::megdnn::param::Convolution::Mode;
  453. using Sparse = ::megdnn::param::Convolution::Sparse;
  454. using Format = ::megdnn::param::Convolution::Format;
  455. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  456. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  457. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  458. uint32_t pad_h = 0;
  459. uint32_t pad_w = 0;
  460. uint32_t stride_h = 1;
  461. uint32_t stride_w = 1;
  462. uint32_t dilate_h = 1;
  463. uint32_t dilate_w = 1;
  464. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  465. Format format = ::megdnn::param::Convolution::Format::NCHW;
  466. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  467. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  468. uint64_t workspace_limit = 18446744073709551615ull;
  469. ::megdnn::DType dtype;
  470. ConvolutionBackwardData() = default;
  471. ConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
  472. set_scope(scope_);
  473. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  474. }
  475. ConvolutionBackwardData(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
  476. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  477. }
  478. ::megdnn::param::Convolution param() const {
  479. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  480. }
  481. ::megdnn::param::ExecutionPolicy policy() const {
  482. return {strategy, workspace_limit};
  483. }
  484. };
  485. class Copy : public OpDefImplBase<Copy> {
  486. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  487. public:
  488. ::mgb::CompNode comp_node;
  489. Copy() = default;
  490. Copy(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
  491. };
  492. class Correlation : public OpDefImplBase<Correlation> {
  493. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  494. public:
  495. using Format = ::megdnn::param::Correlation::Format;
  496. Format format = ::megdnn::param::Correlation::Format::NCHW;
  497. uint32_t kernel_size = 1;
  498. uint32_t max_displacement = 1;
  499. uint32_t stride1 = 1;
  500. uint32_t stride2 = 1;
  501. uint32_t pad_size = 0;
  502. bool is_multiply = true;
  503. Correlation() = default;
  504. Correlation(Format format_, uint32_t kernel_size_, uint32_t max_displacement_, uint32_t stride1_, uint32_t stride2_, uint32_t pad_size_, bool is_multiply_, std::string scope_ = {}): format(format_), kernel_size(kernel_size_), max_displacement(max_displacement_), stride1(stride1_), stride2(stride2_), pad_size(pad_size_), is_multiply(is_multiply_) { set_scope(scope_); }
  505. Correlation(::megdnn::param::Correlation packed_param_0): format(packed_param_0.format), kernel_size(packed_param_0.kernel_size), max_displacement(packed_param_0.max_displacement), stride1(packed_param_0.stride1), stride2(packed_param_0.stride2), pad_size(packed_param_0.pad_size), is_multiply(packed_param_0.is_multiply) {}
  506. ::megdnn::param::Correlation param() const {
  507. return {format, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply};
  508. }
  509. };
  510. class Cumprod : public OpDefImplBase<Cumprod> {
  511. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  512. public:
  513. int32_t axis = 2147483647;
  514. bool exclusive = true;
  515. bool reverse = false;
  516. Cumprod() = default;
  517. Cumprod(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); }
  518. Cumprod(::megdnn::param::Cumprod packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {}
  519. ::megdnn::param::Cumprod param() const {
  520. return {axis, exclusive, reverse};
  521. }
  522. };
  523. class Cumsum : public OpDefImplBase<Cumsum> {
  524. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  525. public:
  526. int32_t axis = 2147483647;
  527. bool exclusive = true;
  528. bool reverse = false;
  529. Cumsum() = default;
  530. Cumsum(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); }
  531. Cumsum(::megdnn::param::Cumsum packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {}
  532. ::megdnn::param::Cumsum param() const {
  533. return {axis, exclusive, reverse};
  534. }
  535. };
  536. class CvtColor : public OpDefImplBase<CvtColor> {
  537. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  538. public:
  539. using Mode = ::megdnn::param::CvtColor::Mode;
  540. Mode mode = ::megdnn::param::CvtColor::Mode::RGB2GRAY;
  541. CvtColor() = default;
  542. CvtColor(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  543. CvtColor(::megdnn::param::CvtColor packed_param_0): mode(packed_param_0.mode) {}
  544. ::megdnn::param::CvtColor param() const {
  545. return {mode};
  546. }
  547. };
  548. class DeformableConv : public OpDefImplBase<DeformableConv> {
  549. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  550. public:
  551. using Mode = ::megdnn::param::Convolution::Mode;
  552. using Sparse = ::megdnn::param::Convolution::Sparse;
  553. using Format = ::megdnn::param::Convolution::Format;
  554. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  555. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  556. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  557. uint32_t pad_h = 0;
  558. uint32_t pad_w = 0;
  559. uint32_t stride_h = 1;
  560. uint32_t stride_w = 1;
  561. uint32_t dilate_h = 1;
  562. uint32_t dilate_w = 1;
  563. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  564. Format format = ::megdnn::param::Convolution::Format::NCHW;
  565. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  566. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  567. uint64_t workspace_limit = 18446744073709551615ull;
  568. DeformableConv() = default;
  569. DeformableConv(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
  570. set_scope(scope_);
  571. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  572. }
  573. DeformableConv(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  574. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  575. }
  576. ::megdnn::param::Convolution param() const {
  577. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  578. }
  579. ::megdnn::param::ExecutionPolicy policy() const {
  580. return {strategy, workspace_limit};
  581. }
  582. };
  583. class DeformablePSROIPooling : public OpDefImplBase<DeformablePSROIPooling> {
  584. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  585. public:
  586. bool no_trans = true;
  587. float spatial_scale = 1;
  588. float trans_std = 1;
  589. uint32_t pooled_h = 1;
  590. uint32_t pooled_w = 1;
  591. uint32_t part_size = 1;
  592. uint32_t sample_per_part = 1;
  593. DeformablePSROIPooling() = default;
  594. DeformablePSROIPooling(bool no_trans_, float spatial_scale_, float trans_std_, uint32_t pooled_h_, uint32_t pooled_w_, uint32_t part_size_, uint32_t sample_per_part_, std::string scope_ = {}): no_trans(no_trans_), spatial_scale(spatial_scale_), trans_std(trans_std_), pooled_h(pooled_h_), pooled_w(pooled_w_), part_size(part_size_), sample_per_part(sample_per_part_) { set_scope(scope_); }
  595. DeformablePSROIPooling(::megdnn::param::DeformablePSROIPooling packed_param_0): no_trans(packed_param_0.no_trans), spatial_scale(packed_param_0.spatial_scale), trans_std(packed_param_0.trans_std), pooled_h(packed_param_0.pooled_h), pooled_w(packed_param_0.pooled_w), part_size(packed_param_0.part_size), sample_per_part(packed_param_0.sample_per_part) {}
  596. ::megdnn::param::DeformablePSROIPooling param() const {
  597. return {no_trans, spatial_scale, trans_std, pooled_h, pooled_w, part_size, sample_per_part};
  598. }
  599. };
  600. class Diag : public OpDefImplBase<Diag> {
  601. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  602. public:
  603. int32_t k = 0;
  604. Diag() = default;
  605. Diag(int32_t k_, std::string scope_ = {}): k(k_) { set_scope(scope_); }
  606. Diag(::megdnn::param::Diag packed_param_0): k(packed_param_0.k) {}
  607. ::megdnn::param::Diag param() const {
  608. return {k};
  609. }
  610. };
  611. class Dimshuffle : public OpDefImplBase<Dimshuffle> {
  612. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  613. public:
  614. std::vector<int32_t> pattern;
  615. Dimshuffle() = default;
  616. Dimshuffle(std::vector<int32_t> pattern_, std::string scope_ = {}): pattern(pattern_) { set_scope(scope_); }
  617. };
  618. class Dot : public OpDefImplBase<Dot> {
  619. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  620. public:
  621. Dot() = default;
  622. Dot(::megdnn::param::Empty) {}
  623. ::megdnn::param::Empty param() const {
  624. return {};
  625. }
  626. };
  627. class Dropout : public OpDefImplBase<Dropout> {
  628. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  629. public:
  630. float drop_prob = 0;
  631. uint64_t seed = 0;
  632. size_t handle;
  633. Dropout() = default;
  634. Dropout(float drop_prob_, uint64_t seed_, size_t handle_, std::string scope_ = {}): drop_prob(drop_prob_), seed(seed_), handle(handle_) { set_scope(scope_); }
  635. Dropout(::megdnn::param::Dropout packed_param_0, size_t handle_): drop_prob(packed_param_0.drop_prob), seed(packed_param_0.seed), handle(handle_) {}
  636. ::megdnn::param::Dropout param() const {
  637. return {drop_prob, seed};
  638. }
  639. };
  640. class Elemwise : public OpDefImplBase<Elemwise> {
  641. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  642. public:
  643. using Mode = ::megdnn::param::Elemwise::Mode;
  644. Mode mode = ::megdnn::param::Elemwise::Mode::RELU;
  645. Elemwise() = default;
  646. Elemwise(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  647. Elemwise(::megdnn::param::Elemwise packed_param_0): mode(packed_param_0.mode) {}
  648. ::megdnn::param::Elemwise param() const {
  649. return {mode};
  650. }
  651. };
  652. template <>
  653. struct ToStringTrait<Elemwise::Mode> {
  654. std::string operator()(Elemwise::Mode e) const {
  655. switch (e) {
  656. case Elemwise::Mode::RELU: return "RELU";
  657. case Elemwise::Mode::ABS: return "ABS";
  658. case Elemwise::Mode::ACOS: return "ACOS";
  659. case Elemwise::Mode::ASIN: return "ASIN";
  660. case Elemwise::Mode::CEIL: return "CEIL";
  661. case Elemwise::Mode::COS: return "COS";
  662. case Elemwise::Mode::EXP: return "EXP";
  663. case Elemwise::Mode::EXPM1: return "EXPM1";
  664. case Elemwise::Mode::FLOOR: return "FLOOR";
  665. case Elemwise::Mode::LOG: return "LOG";
  666. case Elemwise::Mode::LOG1P: return "LOG1P";
  667. case Elemwise::Mode::NEGATE: return "NEGATE";
  668. case Elemwise::Mode::SIGMOID: return "SIGMOID";
  669. case Elemwise::Mode::SIN: return "SIN";
  670. case Elemwise::Mode::TANH: return "TANH";
  671. case Elemwise::Mode::ABS_GRAD: return "ABS_GRAD";
  672. case Elemwise::Mode::ADD: return "ADD";
  673. case Elemwise::Mode::FLOOR_DIV: return "FLOOR_DIV";
  674. case Elemwise::Mode::MAX: return "MAX";
  675. case Elemwise::Mode::MIN: return "MIN";
  676. case Elemwise::Mode::MOD: return "MOD";
  677. case Elemwise::Mode::MUL: return "MUL";
  678. case Elemwise::Mode::POW: return "POW";
  679. case Elemwise::Mode::SIGMOID_GRAD: return "SIGMOID_GRAD";
  680. case Elemwise::Mode::SUB: return "SUB";
  681. case Elemwise::Mode::SWITCH_GT0: return "SWITCH_GT0";
  682. case Elemwise::Mode::TANH_GRAD: return "TANH_GRAD";
  683. case Elemwise::Mode::TRUE_DIV: return "TRUE_DIV";
  684. case Elemwise::Mode::LOG_SUM_EXP: return "LOG_SUM_EXP";
  685. case Elemwise::Mode::LT: return "LT";
  686. case Elemwise::Mode::LEQ: return "LEQ";
  687. case Elemwise::Mode::EQ: return "EQ";
  688. case Elemwise::Mode::SHL: return "SHL";
  689. case Elemwise::Mode::SHR: return "SHR";
  690. case Elemwise::Mode::COND_LEQ_MOV: return "COND_LEQ_MOV";
  691. case Elemwise::Mode::FUSE_MUL_ADD3: return "FUSE_MUL_ADD3";
  692. case Elemwise::Mode::FUSE_MUL_ADD4: return "FUSE_MUL_ADD4";
  693. case Elemwise::Mode::FUSE_ADD_RELU: return "FUSE_ADD_RELU";
  694. case Elemwise::Mode::FUSE_ADD_SIGMOID: return "FUSE_ADD_SIGMOID";
  695. case Elemwise::Mode::FUSE_ADD_TANH: return "FUSE_ADD_TANH";
  696. case Elemwise::Mode::FAST_TANH: return "FAST_TANH";
  697. case Elemwise::Mode::FAST_TANH_GRAD: return "FAST_TANH_GRAD";
  698. case Elemwise::Mode::ROUND: return "ROUND";
  699. case Elemwise::Mode::RMULH: return "RMULH";
  700. case Elemwise::Mode::ATAN2: return "ATAN2";
  701. case Elemwise::Mode::ERF: return "ERF";
  702. case Elemwise::Mode::ERFINV: return "ERFINV";
  703. case Elemwise::Mode::ERFC: return "ERFC";
  704. case Elemwise::Mode::ERFCINV: return "ERFCINV";
  705. case Elemwise::Mode::H_SWISH: return "H_SWISH";
  706. case Elemwise::Mode::H_SWISH_GRAD: return "H_SWISH_GRAD";
  707. case Elemwise::Mode::FUSE_ADD_H_SWISH: return "FUSE_ADD_H_SWISH";
  708. case Elemwise::Mode::NOT: return "NOT";
  709. case Elemwise::Mode::AND: return "AND";
  710. case Elemwise::Mode::OR: return "OR";
  711. case Elemwise::Mode::XOR: return "XOR";
  712. case Elemwise::Mode::SILU: return "SILU";
  713. case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD";
  714. case Elemwise::Mode::GELU: return "GELU";
  715. case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD";
  716. case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV";
  717. case Elemwise::Mode::SINH: return "SINH";
  718. case Elemwise::Mode::COSH: return "COSH";
  719. case Elemwise::Mode::ASINH: return "ASINH";
  720. case Elemwise::Mode::ACOSH: return "ACOSH";
  721. case Elemwise::Mode::ATANH: return "ATANH";
  722. case Elemwise::Mode::TAN: return "TAN";
  723. case Elemwise::Mode::ASINH_GRAD: return "ASINH_GRAD";
  724. case Elemwise::Mode::ACOSH_GRAD: return "ACOSH_GRAD";
  725. case Elemwise::Mode::ATANH_GRAD: return "ATANH_GRAD";
  726. case Elemwise::Mode::PRELU: return "PRELU";
  727. case Elemwise::Mode::CLIP: return "CLIP";
  728. case Elemwise::Mode::PRELU_GRAD: return "PRELU_GRAD";
  729. case Elemwise::Mode::SOFTPLUS: return "SOFTPLUS";
  730. case Elemwise::Mode::SOFTPLUS_GRAD: return "SOFTPLUS_GRAD";
  731. case Elemwise::Mode::RELU6: return "RELU6";
  732. case Elemwise::Mode::RELU6_GRAD: return "RELU6_GRAD";
  733. case Elemwise::Mode::HSIGMOID: return "HSIGMOID";
  734. case Elemwise::Mode::HSIGMOID_GRAD: return "HSIGMOID_GRAD";
  735. case Elemwise::Mode::LOGSIGMOID: return "LOGSIGMOID";
  736. case Elemwise::Mode::SQRT: return "SQRT";
  737. case Elemwise::Mode::SQUARE: return "SQUARE";
  738. case Elemwise::Mode::SIGN: return "SIGN";
  739. case Elemwise::Mode::SAFE_DIV: return "SAFE_DIV";
  740. case Elemwise::Mode::NEQ: return "NEQ";
  741. case Elemwise::Mode::ISNAN: return "ISNAN";
  742. case Elemwise::Mode::ISINF: return "ISINF";
  743. default:
  744. return "Elemwise::Mode::Unknown";
  745. }
  746. }
  747. };
  748. class ElemwiseMultiType : public OpDefImplBase<ElemwiseMultiType> {
  749. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  750. public:
  751. using Mode = ::megdnn::param::ElemwiseMultiType::Mode;
  752. Mode mode = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
  753. ::megdnn::DType dtype;
  754. ElemwiseMultiType() = default;
  755. ElemwiseMultiType(Mode mode_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), dtype(dtype_) { set_scope(scope_); }
  756. ElemwiseMultiType(::megdnn::param::ElemwiseMultiType packed_param_0, ::megdnn::DType dtype_): mode(packed_param_0.mode), dtype(dtype_) {}
  757. ::megdnn::param::ElemwiseMultiType param() const {
  758. return {mode};
  759. }
  760. };
  761. template <>
  762. struct ToStringTrait<ElemwiseMultiType::Mode> {
  763. std::string operator()(ElemwiseMultiType::Mode e) const {
  764. switch (e) {
  765. case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32: return "FUSE_MUL_ADD3_INT16x32x32x32";
  766. case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: return "FUSE_MUL_ADD3_IXxF32xF32xI8";
  767. case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8: return "ROUND_SHR_SATURATE_IXxI8xI8";
  768. case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8";
  769. case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8";
  770. case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16: return "ROUND_SHR_SATURATE_IXxI8xI16";
  771. case ElemwiseMultiType::Mode::QADD: return "QADD";
  772. case ElemwiseMultiType::Mode::QFUSE_ADD_RELU: return "QFUSE_ADD_RELU";
  773. case ElemwiseMultiType::Mode::QMUL: return "QMUL";
  774. case ElemwiseMultiType::Mode::QMIN: return "QMIN";
  775. case ElemwiseMultiType::Mode::QMAX: return "QMAX";
  776. case ElemwiseMultiType::Mode::QSUB: return "QSUB";
  777. case ElemwiseMultiType::Mode::QTRUE_DIV: return "QTRUE_DIV";
  778. case ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID: return "QFUSE_ADD_SIGMOID";
  779. case ElemwiseMultiType::Mode::QFUSE_ADD_TANH: return "QFUSE_ADD_TANH";
  780. case ElemwiseMultiType::Mode::QRELU: return "QRELU";
  781. case ElemwiseMultiType::Mode::QABS: return "QABS";
  782. case ElemwiseMultiType::Mode::QSIGMOID: return "QSIGMOID";
  783. case ElemwiseMultiType::Mode::QEXP: return "QEXP";
  784. case ElemwiseMultiType::Mode::QTANH: return "QTANH";
  785. case ElemwiseMultiType::Mode::QFUSE_MUL_ADD3: return "QFUSE_MUL_ADD3";
  786. case ElemwiseMultiType::Mode::QFAST_TANH: return "QFAST_TANH";
  787. case ElemwiseMultiType::Mode::QNEGATE: return "QNEGATE";
  788. case ElemwiseMultiType::Mode::QACOS: return "QACOS";
  789. case ElemwiseMultiType::Mode::QASIN: return "QASIN";
  790. case ElemwiseMultiType::Mode::QCEIL: return "QCEIL";
  791. case ElemwiseMultiType::Mode::QCOS: return "QCOS";
  792. case ElemwiseMultiType::Mode::QEXPM1: return "QEXPM1";
  793. case ElemwiseMultiType::Mode::QFLOOR: return "QFLOOR";
  794. case ElemwiseMultiType::Mode::QLOG: return "QLOG";
  795. case ElemwiseMultiType::Mode::QLOG1P: return "QLOG1P";
  796. case ElemwiseMultiType::Mode::QSIN: return "QSIN";
  797. case ElemwiseMultiType::Mode::QROUND: return "QROUND";
  798. case ElemwiseMultiType::Mode::QERF: return "QERF";
  799. case ElemwiseMultiType::Mode::QERFINV: return "QERFINV";
  800. case ElemwiseMultiType::Mode::QERFC: return "QERFC";
  801. case ElemwiseMultiType::Mode::QERFCINV: return "QERFCINV";
  802. case ElemwiseMultiType::Mode::QABS_GRAD: return "QABS_GRAD";
  803. case ElemwiseMultiType::Mode::QFLOOR_DIV: return "QFLOOR_DIV";
  804. case ElemwiseMultiType::Mode::QMOD: return "QMOD";
  805. case ElemwiseMultiType::Mode::QSIGMOID_GRAD: return "QSIGMOID_GRAD";
  806. case ElemwiseMultiType::Mode::QSWITCH_GT0: return "QSWITCH_GT0";
  807. case ElemwiseMultiType::Mode::QTANH_GRAD: return "QTANH_GRAD";
  808. case ElemwiseMultiType::Mode::QLT: return "QLT";
  809. case ElemwiseMultiType::Mode::QLEQ: return "QLEQ";
  810. case ElemwiseMultiType::Mode::QEQ: return "QEQ";
  811. case ElemwiseMultiType::Mode::QPOW: return "QPOW";
  812. case ElemwiseMultiType::Mode::QLOG_SUM_EXP: return "QLOG_SUM_EXP";
  813. case ElemwiseMultiType::Mode::QFAST_TANH_GRAD: return "QFAST_TANH_GRAD";
  814. case ElemwiseMultiType::Mode::QATAN2: return "QATAN2";
  815. case ElemwiseMultiType::Mode::QCOND_LEQ_MOV: return "QCOND_LEQ_MOV";
  816. case ElemwiseMultiType::Mode::QH_SWISH: return "QH_SWISH";
  817. case ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH: return "QFUSE_ADD_H_SWISH";
  818. case ElemwiseMultiType::Mode::QH_SWISH_GRAD: return "QH_SWISH_GRAD";
  819. case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_ADD3_INT16xF32xF32xF32";
  820. case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32";
  821. case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32";
  822. case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV";
  823. case ElemwiseMultiType::Mode::EQ: return "EQ";
  824. case ElemwiseMultiType::Mode::NEQ: return "NEQ";
  825. case ElemwiseMultiType::Mode::LT: return "LT";
  826. case ElemwiseMultiType::Mode::LEQ: return "LEQ";
  827. case ElemwiseMultiType::Mode::ISNAN: return "ISNAN";
  828. case ElemwiseMultiType::Mode::ISINF: return "ISINF";
  829. default:
  830. return "ElemwiseMultiType::Mode::Unknown";
  831. }
  832. }
  833. };
  834. class ExternOpr : public OpDefImplBase<ExternOpr> {
  835. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  836. public:
  837. std::vector<std::vector<size_t>> output_shapes;
  838. std::string name;
  839. std::string data;
  840. size_t data_len;
  841. std::vector<::megdnn::DType> output_dtypes;
  842. ExternOpr() = default;
  843. ExternOpr(std::vector<std::vector<size_t>> output_shapes_, std::string name_, std::string data_, size_t data_len_, std::vector<::megdnn::DType> output_dtypes_, std::string scope_ = {}): output_shapes(output_shapes_), name(name_), data(data_), data_len(data_len_), output_dtypes(output_dtypes_) { set_scope(scope_); }
  844. };
  845. class Eye : public OpDefImplBase<Eye> {
  846. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  847. public:
  848. int32_t k = 0;
  849. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  850. ::mgb::CompNode comp_node;
  851. Eye() = default;
  852. Eye(int32_t k_, ::megdnn::DType dtype_, ::mgb::CompNode comp_node_, std::string scope_ = {}): k(k_), dtype(dtype_), comp_node(comp_node_) { set_scope(scope_); }
  853. };
  854. class FakeQuant : public OpDefImplBase<FakeQuant> {
  855. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  856. public:
  857. int32_t qmin = -2147483648;
  858. int32_t qmax = 2147483647;
  859. FakeQuant() = default;
  860. FakeQuant(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  861. FakeQuant(::megdnn::param::FakeQuant packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  862. ::megdnn::param::FakeQuant param() const {
  863. return {qmin, qmax};
  864. }
  865. };
  866. class FastpathCopy : public OpDefImplBase<FastpathCopy> {
  867. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  868. public:
  869. FastpathCopy() = default;
  870. };
  871. class GammaRNG : public OpDefImplBase<GammaRNG> {
  872. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  873. public:
  874. uint64_t seed = 0;
  875. size_t handle;
  876. GammaRNG() = default;
  877. GammaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  878. GammaRNG(::megdnn::param::GammaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  879. ::megdnn::param::GammaRNG param() const {
  880. return {seed};
  881. }
  882. };
  883. class GaussianRNG : public OpDefImplBase<GaussianRNG> {
  884. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  885. public:
  886. uint64_t seed = 0;
  887. float mean = 0;
  888. float std = 1;
  889. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  890. size_t handle;
  891. GaussianRNG() = default;
  892. GaussianRNG(uint64_t seed_, float mean_, float std_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), mean(mean_), std(std_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
  893. };
  894. class GetVarShape : public OpDefImplBase<GetVarShape> {
  895. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  896. public:
  897. int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
  898. GetVarShape() = default;
  899. GetVarShape(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  900. GetVarShape(::megdnn::param::OptionalAxisV1 packed_param_0): axis(packed_param_0.axis) {}
  901. ::megdnn::param::OptionalAxisV1 param() const {
  902. return {axis};
  903. }
  904. };
  905. class GroupLocal : public OpDefImplBase<GroupLocal> {
  906. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  907. public:
  908. using Mode = ::megdnn::param::Convolution::Mode;
  909. using Sparse = ::megdnn::param::Convolution::Sparse;
  910. using Format = ::megdnn::param::Convolution::Format;
  911. using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
  912. Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
  913. uint32_t pad_h = 0;
  914. uint32_t pad_w = 0;
  915. uint32_t stride_h = 1;
  916. uint32_t stride_w = 1;
  917. uint32_t dilate_h = 1;
  918. uint32_t dilate_w = 1;
  919. Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
  920. Format format = ::megdnn::param::Convolution::Format::NCHW;
  921. ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
  922. GroupLocal() = default;
  923. GroupLocal(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
  924. GroupLocal(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
  925. ::megdnn::param::Convolution param() const {
  926. return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
  927. }
  928. };
  929. class Identity : public OpDefImplBase<Identity> {
  930. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  931. public:
  932. Identity() = default;
  933. };
  934. class Images2Neibs : public OpDefImplBase<Images2Neibs> {
  935. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  936. public:
  937. uint32_t pad_h = 0;
  938. uint32_t pad_w = 0;
  939. uint32_t stride_h = 1;
  940. uint32_t stride_w = 1;
  941. uint32_t dilate_h = 1;
  942. uint32_t dilate_w = 1;
  943. uint32_t window_h = 3;
  944. uint32_t window_w = 3;
  945. Images2Neibs() = default;
  946. Images2Neibs(uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
  947. Images2Neibs(::megdnn::param::Images2Neibs packed_param_0): pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
  948. ::megdnn::param::Images2Neibs param() const {
  949. return {pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
  950. }
  951. };
  952. class IncrMeshIndexing : public OpDefImplBase<IncrMeshIndexing> {
  953. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  954. public:
  955. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  956. IncrMeshIndexing() = default;
  957. IncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  958. };
  959. class IncrSubtensor : public OpDefImplBase<IncrSubtensor> {
  960. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  961. public:
  962. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  963. IncrSubtensor() = default;
  964. IncrSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  965. };
  966. class IndexingIncrMultiAxisVec : public OpDefImplBase<IndexingIncrMultiAxisVec> {
  967. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  968. public:
  969. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  970. IndexingIncrMultiAxisVec() = default;
  971. IndexingIncrMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  972. };
  973. class IndexingMultiAxisVec : public OpDefImplBase<IndexingMultiAxisVec> {
  974. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  975. public:
  976. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  977. IndexingMultiAxisVec() = default;
  978. IndexingMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  979. };
  980. class IndexingOneHot : public OpDefImplBase<IndexingOneHot> {
  981. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  982. public:
  983. int32_t axis = 0;
  984. int32_t ndim;
  985. IndexingOneHot() = default;
  986. IndexingOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
  987. IndexingOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
  988. ::megdnn::param::Axis param() const {
  989. return {axis};
  990. }
  991. };
  992. class IndexingSetMultiAxisVec : public OpDefImplBase<IndexingSetMultiAxisVec> {
  993. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  994. public:
  995. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  996. IndexingSetMultiAxisVec() = default;
  997. IndexingSetMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  998. };
  999. class IndexingSetOneHot : public OpDefImplBase<IndexingSetOneHot> {
  1000. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1001. public:
  1002. int32_t axis = 0;
  1003. int32_t ndim;
  1004. IndexingSetOneHot() = default;
  1005. IndexingSetOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
  1006. IndexingSetOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
  1007. ::megdnn::param::Axis param() const {
  1008. return {axis};
  1009. }
  1010. };
  1011. class InplaceAdd : public OpDefImplBase<InplaceAdd> {
  1012. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1013. public:
  1014. InplaceAdd() = default;
  1015. InplaceAdd(::megdnn::param::Empty) {}
  1016. ::megdnn::param::Empty param() const {
  1017. return {};
  1018. }
  1019. };
  1020. class LAMBUpdate : public OpDefImplBase<LAMBUpdate> {
  1021. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1022. public:
  1023. float beta_1 = 1.f;
  1024. float beta_2 = 1.f;
  1025. float step = 1.f;
  1026. float lr = 1.f;
  1027. float weight_decay = 1.f;
  1028. float eps = 1.f;
  1029. bool bias_correction = true;
  1030. bool always_adapt = false;
  1031. LAMBUpdate() = default;
  1032. LAMBUpdate(float beta_1_, float beta_2_, float step_, float lr_, float weight_decay_, float eps_, bool bias_correction_, bool always_adapt_, std::string scope_ = {}): beta_1(beta_1_), beta_2(beta_2_), step(step_), lr(lr_), weight_decay(weight_decay_), eps(eps_), bias_correction(bias_correction_), always_adapt(always_adapt_) { set_scope(scope_); }
  1033. LAMBUpdate(::megdnn::param::LAMBUpdate packed_param_0): beta_1(packed_param_0.beta_1), beta_2(packed_param_0.beta_2), step(packed_param_0.step), lr(packed_param_0.lr), weight_decay(packed_param_0.weight_decay), eps(packed_param_0.eps), bias_correction(packed_param_0.bias_correction), always_adapt(packed_param_0.always_adapt) {}
  1034. ::megdnn::param::LAMBUpdate param() const {
  1035. return {beta_1, beta_2, step, lr, weight_decay, eps, bias_correction, always_adapt};
  1036. }
  1037. };
  1038. class LRN : public OpDefImplBase<LRN> {
  1039. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1040. public:
  1041. uint32_t n = 5;
  1042. float k = 2.f;
  1043. float alpha = 1e-4f;
  1044. float beta = 0.75f;
  1045. LRN() = default;
  1046. LRN(uint32_t n_, float k_, float alpha_, float beta_, std::string scope_ = {}): n(n_), k(k_), alpha(alpha_), beta(beta_) { set_scope(scope_); }
  1047. LRN(::megdnn::param::LRN packed_param_0): n(packed_param_0.n), k(packed_param_0.k), alpha(packed_param_0.alpha), beta(packed_param_0.beta) {}
  1048. ::megdnn::param::LRN param() const {
  1049. return {n, k, alpha, beta};
  1050. }
  1051. };
  1052. class LSQ : public OpDefImplBase<LSQ> {
  1053. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1054. public:
  1055. int32_t qmin = -2147483648;
  1056. int32_t qmax = 2147483647;
  1057. LSQ() = default;
  1058. LSQ(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  1059. LSQ(::megdnn::param::LSQ packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  1060. ::megdnn::param::LSQ param() const {
  1061. return {qmin, qmax};
  1062. }
  1063. };
  1064. class LSTM : public OpDefImplBase<LSTM> {
  1065. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1066. public:
  1067. using FwdMode = ::megdnn::param::LSTM::FwdMode;
  1068. uint32_t num_layers = 1;
  1069. bool bidirectional = false;
  1070. bool bias = true;
  1071. uint32_t hidden_size = 128;
  1072. uint32_t proj_size = 0;
  1073. float dropout = 0.f;
  1074. FwdMode fwd_mode = ::megdnn::param::LSTM::FwdMode::TRAINING;
  1075. LSTM() = default;
  1076. LSTM(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, uint32_t proj_size_, float dropout_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), proj_size(proj_size_), dropout(dropout_), fwd_mode(fwd_mode_) { set_scope(scope_); }
  1077. LSTM(::megdnn::param::LSTM packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), proj_size(packed_param_0.proj_size), dropout(packed_param_0.dropout), fwd_mode(packed_param_0.fwd_mode) {}
  1078. ::megdnn::param::LSTM param() const {
  1079. return {num_layers, bidirectional, bias, hidden_size, proj_size, dropout, fwd_mode};
  1080. }
  1081. };
  1082. class LSTMCell : public OpDefImplBase<LSTMCell> {
  1083. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1084. public:
  1085. LSTMCell() = default;
  1086. LSTMCell(::megdnn::param::Empty) {}
  1087. ::megdnn::param::Empty param() const {
  1088. return {};
  1089. }
  1090. };
  1091. class LayerNorm : public OpDefImplBase<LayerNorm> {
  1092. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1093. public:
  1094. bool affine = true;
  1095. float eps = 1e-5f;
  1096. uint64_t normalized_dim = 1;
  1097. uint64_t normalized_size = 1;
  1098. LayerNorm() = default;
  1099. LayerNorm(bool affine_, float eps_, uint64_t normalized_dim_, uint64_t normalized_size_, std::string scope_ = {}): affine(affine_), eps(eps_), normalized_dim(normalized_dim_), normalized_size(normalized_size_) { set_scope(scope_); }
  1100. LayerNorm(::megdnn::param::LayerNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), normalized_dim(packed_param_0.normalized_dim), normalized_size(packed_param_0.normalized_size) {}
  1101. ::megdnn::param::LayerNorm param() const {
  1102. return {affine, eps, normalized_dim, normalized_size};
  1103. }
  1104. };
  1105. class Linspace : public OpDefImplBase<Linspace> {
  1106. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1107. public:
  1108. bool endpoint = true;
  1109. ::mgb::CompNode comp_node;
  1110. Linspace() = default;
  1111. Linspace(bool endpoint_, ::mgb::CompNode comp_node_, std::string scope_ = {}): endpoint(endpoint_), comp_node(comp_node_) { set_scope(scope_); }
  1112. Linspace(::megdnn::param::Linspace packed_param_0, ::mgb::CompNode comp_node_): endpoint(packed_param_0.endpoint), comp_node(comp_node_) {}
  1113. ::megdnn::param::Linspace param() const {
  1114. return {endpoint};
  1115. }
  1116. };
  1117. class MagicMindRuntime : public OpDefImplBase<MagicMindRuntime> {
  1118. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1119. public:
  1120. std::string buf;
  1121. size_t buf_size;
  1122. MagicMindRuntime() = default;
  1123. MagicMindRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
  1124. };
  1125. class MatrixInverse : public OpDefImplBase<MatrixInverse> {
  1126. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1127. public:
  1128. MatrixInverse() = default;
  1129. MatrixInverse(::megdnn::param::Empty) {}
  1130. ::megdnn::param::Empty param() const {
  1131. return {};
  1132. }
  1133. };
  1134. class MatrixMul : public OpDefImplBase<MatrixMul> {
  1135. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1136. public:
  1137. using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
  1138. using Format = ::megdnn::param::MatrixMul::Format;
  1139. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  1140. bool transposeA = false;
  1141. bool transposeB = false;
  1142. ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
  1143. Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
  1144. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  1145. uint64_t workspace_limit = 18446744073709551615ull;
  1146. uint32_t dimA;
  1147. uint32_t dimB;
  1148. MatrixMul() = default;
  1149. MatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
  1150. set_scope(scope_);
  1151. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  1152. }
  1153. MatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
  1154. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  1155. }
  1156. ::megdnn::param::MatrixMul param() const {
  1157. return {transposeA, transposeB, compute_mode, format};
  1158. }
  1159. ::megdnn::param::ExecutionPolicy policy() const {
  1160. return {strategy, workspace_limit};
  1161. }
  1162. };
  1163. class MeshIndexing : public OpDefImplBase<MeshIndexing> {
  1164. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1165. public:
  1166. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1167. MeshIndexing() = default;
  1168. MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1169. };
  1170. class NMSKeep : public OpDefImplBase<NMSKeep> {
  1171. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1172. public:
  1173. float iou_thresh;
  1174. uint32_t max_output;
  1175. NMSKeep() = default;
  1176. NMSKeep(float iou_thresh_, uint32_t max_output_, std::string scope_ = {}): iou_thresh(iou_thresh_), max_output(max_output_) { set_scope(scope_); }
  1177. };
  1178. class NvOf : public OpDefImplBase<NvOf> {
  1179. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1180. public:
  1181. uint32_t precision = 1;
  1182. NvOf() = default;
  1183. NvOf(uint32_t precision_, std::string scope_ = {}): precision(precision_) { set_scope(scope_); }
  1184. NvOf(::megdnn::param::NvOf packed_param_0): precision(packed_param_0.precision) {}
  1185. ::megdnn::param::NvOf param() const {
  1186. return {precision};
  1187. }
  1188. };
  1189. class Padding : public OpDefImplBase<Padding> {
  1190. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1191. public:
  1192. using PaddingMode = ::megdnn::param::Padding::PaddingMode;
  1193. uint32_t front_offset_dim0 = 0;
  1194. uint32_t front_offset_dim1 = 0;
  1195. uint32_t front_offset_dim2 = 0;
  1196. uint32_t front_offset_dim3 = 0;
  1197. uint32_t front_offset_dim4 = 0;
  1198. uint32_t front_offset_dim5 = 0;
  1199. uint32_t front_offset_dim6 = 0;
  1200. uint32_t back_offset_dim0 = 0;
  1201. uint32_t back_offset_dim1 = 0;
  1202. uint32_t back_offset_dim2 = 0;
  1203. uint32_t back_offset_dim3 = 0;
  1204. uint32_t back_offset_dim4 = 0;
  1205. uint32_t back_offset_dim5 = 0;
  1206. uint32_t back_offset_dim6 = 0;
  1207. float padding_val = 0;
  1208. PaddingMode padding_mode = ::megdnn::param::Padding::PaddingMode::CONSTANT;
  1209. Padding() = default;
  1210. Padding(uint32_t front_offset_dim0_, uint32_t front_offset_dim1_, uint32_t front_offset_dim2_, uint32_t front_offset_dim3_, uint32_t front_offset_dim4_, uint32_t front_offset_dim5_, uint32_t front_offset_dim6_, uint32_t back_offset_dim0_, uint32_t back_offset_dim1_, uint32_t back_offset_dim2_, uint32_t back_offset_dim3_, uint32_t back_offset_dim4_, uint32_t back_offset_dim5_, uint32_t back_offset_dim6_, float padding_val_, PaddingMode padding_mode_, std::string scope_ = {}): front_offset_dim0(front_offset_dim0_), front_offset_dim1(front_offset_dim1_), front_offset_dim2(front_offset_dim2_), front_offset_dim3(front_offset_dim3_), front_offset_dim4(front_offset_dim4_), front_offset_dim5(front_offset_dim5_), front_offset_dim6(front_offset_dim6_), back_offset_dim0(back_offset_dim0_), back_offset_dim1(back_offset_dim1_), back_offset_dim2(back_offset_dim2_), back_offset_dim3(back_offset_dim3_), back_offset_dim4(back_offset_dim4_), back_offset_dim5(back_offset_dim5_), back_offset_dim6(back_offset_dim6_), padding_val(padding_val_), padding_mode(padding_mode_) { set_scope(scope_); }
  1211. Padding(::megdnn::param::Padding packed_param_0): front_offset_dim0(packed_param_0.front_offset_dim0), front_offset_dim1(packed_param_0.front_offset_dim1), front_offset_dim2(packed_param_0.front_offset_dim2), front_offset_dim3(packed_param_0.front_offset_dim3), front_offset_dim4(packed_param_0.front_offset_dim4), front_offset_dim5(packed_param_0.front_offset_dim5), front_offset_dim6(packed_param_0.front_offset_dim6), back_offset_dim0(packed_param_0.back_offset_dim0), back_offset_dim1(packed_param_0.back_offset_dim1), back_offset_dim2(packed_param_0.back_offset_dim2), back_offset_dim3(packed_param_0.back_offset_dim3), back_offset_dim4(packed_param_0.back_offset_dim4), back_offset_dim5(packed_param_0.back_offset_dim5), back_offset_dim6(packed_param_0.back_offset_dim6), padding_val(packed_param_0.padding_val), padding_mode(packed_param_0.padding_mode) {}
  1212. ::megdnn::param::Padding param() const {
  1213. return {front_offset_dim0, front_offset_dim1, front_offset_dim2, front_offset_dim3, front_offset_dim4, front_offset_dim5, front_offset_dim6, back_offset_dim0, back_offset_dim1, back_offset_dim2, back_offset_dim3, back_offset_dim4, back_offset_dim5, back_offset_dim6, padding_val, padding_mode};
  1214. }
  1215. };
  1216. class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
  1217. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1218. public:
  1219. std::vector<int32_t> offsets;
  1220. ParamPackConcat() = default;
  1221. ParamPackConcat(std::vector<int32_t> offsets_, std::string scope_ = {}): offsets(offsets_) { set_scope(scope_); }
  1222. };
  1223. class ParamPackSplit : public OpDefImplBase<ParamPackSplit> {
  1224. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1225. public:
  1226. std::vector<int32_t> offsets;
  1227. std::vector<std::vector<size_t>> shapes;
  1228. ParamPackSplit() = default;
  1229. ParamPackSplit(std::vector<int32_t> offsets_, std::vector<std::vector<size_t>> shapes_, std::string scope_ = {}): offsets(offsets_), shapes(shapes_) { set_scope(scope_); }
  1230. };
  1231. class PermutationRNG : public OpDefImplBase<PermutationRNG> {
  1232. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1233. public:
  1234. uint64_t seed = 0;
  1235. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32);
  1236. size_t handle;
  1237. PermutationRNG() = default;
  1238. PermutationRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
  1239. };
  1240. class PixelShuffle : public OpDefImplBase<PixelShuffle> {
  1241. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1242. public:
  1243. int32_t factor;
  1244. PixelShuffle() = default;
  1245. PixelShuffle(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
  1246. };
  1247. class PixelShuffleBackward : public OpDefImplBase<PixelShuffleBackward> {
  1248. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1249. public:
  1250. int32_t factor;
  1251. PixelShuffleBackward() = default;
  1252. PixelShuffleBackward(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
  1253. };
  1254. class PoissonRNG : public OpDefImplBase<PoissonRNG> {
  1255. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1256. public:
  1257. uint64_t seed = 0;
  1258. size_t handle;
  1259. PoissonRNG() = default;
  1260. PoissonRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  1261. PoissonRNG(::megdnn::param::PoissonRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  1262. ::megdnn::param::PoissonRNG param() const {
  1263. return {seed};
  1264. }
  1265. };
  1266. class Pooling : public OpDefImplBase<Pooling> {
  1267. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1268. public:
  1269. using Mode = ::megdnn::param::Pooling::Mode;
  1270. using Format = ::megdnn::param::Pooling::Format;
  1271. using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
  1272. Mode mode = ::megdnn::param::Pooling::Mode::MAX;
  1273. uint32_t pad_h = 0;
  1274. uint32_t pad_w = 0;
  1275. uint32_t stride_h = 2;
  1276. uint32_t stride_w = 2;
  1277. uint32_t window_h = 2;
  1278. uint32_t window_w = 2;
  1279. Format format = ::megdnn::param::Pooling::Format::NCHW;
  1280. Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
  1281. uint64_t workspace_limit = 18446744073709551615ull;
  1282. Pooling() = default;
  1283. Pooling(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t window_h_, uint32_t window_w_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), window_h(window_h_), window_w(window_w_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
  1284. set_scope(scope_);
  1285. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  1286. }
  1287. Pooling(::megdnn::param::Pooling packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
  1288. mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
  1289. }
  1290. ::megdnn::param::Pooling param() const {
  1291. return {mode, pad_h, pad_w, stride_h, stride_w, window_h, window_w, format};
  1292. }
  1293. ::megdnn::param::ExecutionPolicy policy() const {
  1294. return {strategy, workspace_limit};
  1295. }
  1296. };
  1297. class RNN : public OpDefImplBase<RNN> {
  1298. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1299. public:
  1300. using NonlineMode = ::megdnn::param::RNN::NonlineMode;
  1301. using FwdMode = ::megdnn::param::RNN::FwdMode;
  1302. uint32_t num_layers = 1;
  1303. bool bidirectional = false;
  1304. bool bias = true;
  1305. uint32_t hidden_size = 128;
  1306. float dropout = 0.f;
  1307. NonlineMode nonlineMode = ::megdnn::param::RNN::NonlineMode::IDENTITY;
  1308. FwdMode fwd_mode = ::megdnn::param::RNN::FwdMode::TRAINING;
  1309. RNN() = default;
  1310. RNN(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, float dropout_, NonlineMode nonlineMode_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), dropout(dropout_), nonlineMode(nonlineMode_), fwd_mode(fwd_mode_) { set_scope(scope_); }
  1311. RNN(::megdnn::param::RNN packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), dropout(packed_param_0.dropout), nonlineMode(packed_param_0.nonlineMode), fwd_mode(packed_param_0.fwd_mode) {}
  1312. ::megdnn::param::RNN param() const {
  1313. return {num_layers, bidirectional, bias, hidden_size, dropout, nonlineMode, fwd_mode};
  1314. }
  1315. };
  1316. class RNNCell : public OpDefImplBase<RNNCell> {
  1317. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1318. public:
  1319. using NonlineMode = ::megdnn::param::RNNCell::NonlineMode;
  1320. NonlineMode nonlineMode = ::megdnn::param::RNNCell::NonlineMode::IDENTITY;
  1321. RNNCell() = default;
  1322. RNNCell(NonlineMode nonlineMode_, std::string scope_ = {}): nonlineMode(nonlineMode_) { set_scope(scope_); }
  1323. RNNCell(::megdnn::param::RNNCell packed_param_0): nonlineMode(packed_param_0.nonlineMode) {}
  1324. ::megdnn::param::RNNCell param() const {
  1325. return {nonlineMode};
  1326. }
  1327. };
  1328. class ROIAlign : public OpDefImplBase<ROIAlign> {
  1329. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1330. public:
  1331. using Mode = ::megdnn::param::ROIAlign::Mode;
  1332. using Format = ::megdnn::param::ROIAlign::Format;
  1333. Mode mode = ::megdnn::param::ROIAlign::Mode::MAX;
  1334. Format format = ::megdnn::param::ROIAlign::Format::NCHW;
  1335. float spatial_scale = 1.0;
  1336. float offset = 0.0;
  1337. uint32_t pooled_height = 1;
  1338. uint32_t pooled_width = 1;
  1339. uint32_t sample_height = 2;
  1340. uint32_t sample_width = 2;
  1341. ROIAlign() = default;
  1342. ROIAlign(Mode mode_, Format format_, float spatial_scale_, float offset_, uint32_t pooled_height_, uint32_t pooled_width_, uint32_t sample_height_, uint32_t sample_width_, std::string scope_ = {}): mode(mode_), format(format_), spatial_scale(spatial_scale_), offset(offset_), pooled_height(pooled_height_), pooled_width(pooled_width_), sample_height(sample_height_), sample_width(sample_width_) { set_scope(scope_); }
  1343. ROIAlign(::megdnn::param::ROIAlign packed_param_0): mode(packed_param_0.mode), format(packed_param_0.format), spatial_scale(packed_param_0.spatial_scale), offset(packed_param_0.offset), pooled_height(packed_param_0.pooled_height), pooled_width(packed_param_0.pooled_width), sample_height(packed_param_0.sample_height), sample_width(packed_param_0.sample_width) {}
  1344. ::megdnn::param::ROIAlign param() const {
  1345. return {mode, format, spatial_scale, offset, pooled_height, pooled_width, sample_height, sample_width};
  1346. }
  1347. };
  1348. class ROIPooling : public OpDefImplBase<ROIPooling> {
  1349. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1350. public:
  1351. using Mode = ::megdnn::param::ROIPooling::Mode;
  1352. Mode mode = ::megdnn::param::ROIPooling::Mode::MAX;
  1353. float scale = 1.f;
  1354. ROIPooling() = default;
  1355. ROIPooling(Mode mode_, float scale_, std::string scope_ = {}): mode(mode_), scale(scale_) { set_scope(scope_); }
  1356. ROIPooling(::megdnn::param::ROIPooling packed_param_0): mode(packed_param_0.mode), scale(packed_param_0.scale) {}
  1357. ::megdnn::param::ROIPooling param() const {
  1358. return {mode, scale};
  1359. }
  1360. };
  1361. class Reduce : public OpDefImplBase<Reduce> {
  1362. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1363. public:
  1364. using Mode = ::megdnn::param::Reduce::Mode;
  1365. using DataType = ::megdnn::param::Reduce::DataType;
  1366. Mode mode = ::megdnn::param::Reduce::Mode::SUM;
  1367. int32_t axis = 2147483647;
  1368. DataType data_type = ::megdnn::param::Reduce::DataType::DEFAULT;
  1369. bool keepdim = true;
  1370. Reduce() = default;
  1371. Reduce(Mode mode_, int32_t axis_, DataType data_type_, bool keepdim_, std::string scope_ = {}): mode(mode_), axis(axis_), data_type(data_type_), keepdim(keepdim_) { set_scope(scope_); }
  1372. Reduce(::megdnn::param::Reduce packed_param_0, bool keepdim_): mode(packed_param_0.mode), axis(packed_param_0.axis), data_type(packed_param_0.data_type), keepdim(keepdim_) {}
  1373. ::megdnn::param::Reduce param() const {
  1374. return {mode, axis, data_type};
  1375. }
  1376. };
  1377. class Remap : public OpDefImplBase<Remap> {
  1378. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1379. public:
  1380. using InterpolationMode = ::megdnn::param::Remap::InterpolationMode;
  1381. using BorderMode = ::megdnn::param::Remap::BorderMode;
  1382. using Format = ::megdnn::param::Remap::Format;
  1383. InterpolationMode imode = ::megdnn::param::Remap::InterpolationMode::LINEAR;
  1384. BorderMode border_type = ::megdnn::param::Remap::BorderMode::REPLICATE;
  1385. Format format = ::megdnn::param::Remap::Format::NHWC;
  1386. float scalar = 0.f;
  1387. Remap() = default;
  1388. Remap(InterpolationMode imode_, BorderMode border_type_, Format format_, float scalar_, std::string scope_ = {}): imode(imode_), border_type(border_type_), format(format_), scalar(scalar_) { set_scope(scope_); }
  1389. Remap(::megdnn::param::Remap packed_param_0): imode(packed_param_0.imode), border_type(packed_param_0.border_type), format(packed_param_0.format), scalar(packed_param_0.scalar) {}
  1390. ::megdnn::param::Remap param() const {
  1391. return {imode, border_type, format, scalar};
  1392. }
  1393. };
  1394. class RemoteRecv : public OpDefImplBase<RemoteRecv> {
  1395. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1396. public:
  1397. std::string key;
  1398. std::string addr;
  1399. uint32_t port;
  1400. uint32_t rank_from;
  1401. ::mgb::CompNode cn;
  1402. std::vector<int32_t> shape;
  1403. ::megdnn::DType dtype;
  1404. std::string backend;
  1405. RemoteRecv() = default;
  1406. RemoteRecv(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_from_, ::mgb::CompNode cn_, std::vector<int32_t> shape_, ::megdnn::DType dtype_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_from(rank_from_), cn(cn_), shape(shape_), dtype(dtype_), backend(backend_) { set_scope(scope_); }
  1407. };
  1408. class RemoteSend : public OpDefImplBase<RemoteSend> {
  1409. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1410. public:
  1411. std::string key;
  1412. std::string addr;
  1413. uint32_t port;
  1414. uint32_t rank_to;
  1415. std::string backend;
  1416. RemoteSend() = default;
  1417. RemoteSend(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_to_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_to(rank_to_), backend(backend_) { set_scope(scope_); }
  1418. };
  1419. class RemoveAxis : public OpDefImplBase<RemoveAxis> {
  1420. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1421. public:
  1422. std::vector<int32_t> axis;
  1423. RemoveAxis() = default;
  1424. RemoveAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  1425. };
  1426. class Reshape : public OpDefImplBase<Reshape> {
  1427. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1428. public:
  1429. int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
  1430. std::vector<int32_t> shape;
  1431. Reshape() = default;
  1432. Reshape(int32_t axis_, std::vector<int32_t> shape_, std::string scope_ = {}): axis(axis_), shape(shape_) { set_scope(scope_); }
  1433. Reshape(::megdnn::param::OptionalAxisV1 packed_param_0, std::vector<int32_t> shape_): axis(packed_param_0.axis), shape(shape_) {}
  1434. ::megdnn::param::OptionalAxisV1 param() const {
  1435. return {axis};
  1436. }
  1437. };
  1438. class Resize : public OpDefImplBase<Resize> {
  1439. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1440. public:
  1441. using InterpolationMode = ::megdnn::param::Resize::InterpolationMode;
  1442. using Format = ::megdnn::param::Resize::Format;
  1443. InterpolationMode imode = ::megdnn::param::Resize::InterpolationMode::LINEAR;
  1444. Format format = ::megdnn::param::Resize::Format::NHWC;
  1445. Resize() = default;
  1446. Resize(InterpolationMode imode_, Format format_, std::string scope_ = {}): imode(imode_), format(format_) { set_scope(scope_); }
  1447. Resize(::megdnn::param::Resize packed_param_0): imode(packed_param_0.imode), format(packed_param_0.format) {}
  1448. ::megdnn::param::Resize param() const {
  1449. return {imode, format};
  1450. }
  1451. };
  1452. class SVD : public OpDefImplBase<SVD> {
  1453. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1454. public:
  1455. bool full_matrices = false;
  1456. bool compute_uv = true;
  1457. SVD() = default;
  1458. SVD(bool full_matrices_, bool compute_uv_, std::string scope_ = {}): full_matrices(full_matrices_), compute_uv(compute_uv_) { set_scope(scope_); }
  1459. SVD(::megdnn::param::SVD packed_param_0): full_matrices(packed_param_0.full_matrices), compute_uv(packed_param_0.compute_uv) {}
  1460. ::megdnn::param::SVD param() const {
  1461. return {full_matrices, compute_uv};
  1462. }
  1463. };
  1464. class SetMeshIndexing : public OpDefImplBase<SetMeshIndexing> {
  1465. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1466. public:
  1467. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1468. SetMeshIndexing() = default;
  1469. SetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1470. };
  1471. class SetSubtensor : public OpDefImplBase<SetSubtensor> {
  1472. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1473. public:
  1474. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1475. SetSubtensor() = default;
  1476. SetSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1477. };
  1478. class ShuffleRNG : public OpDefImplBase<ShuffleRNG> {
  1479. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1480. public:
  1481. uint64_t seed = 0;
  1482. size_t handle;
  1483. ShuffleRNG() = default;
  1484. ShuffleRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
  1485. ShuffleRNG(::megdnn::param::ShuffleRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
  1486. ::megdnn::param::ShuffleRNG param() const {
  1487. return {seed};
  1488. }
  1489. };
  1490. class SlidingWindowTranspose : public OpDefImplBase<SlidingWindowTranspose> {
  1491. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1492. public:
  1493. uint32_t out_h = 0;
  1494. uint32_t out_w = 0;
  1495. uint32_t pad_h = 0;
  1496. uint32_t pad_w = 0;
  1497. uint32_t stride_h = 1;
  1498. uint32_t stride_w = 1;
  1499. uint32_t dilate_h = 1;
  1500. uint32_t dilate_w = 1;
  1501. uint32_t window_h = 3;
  1502. uint32_t window_w = 3;
  1503. SlidingWindowTranspose() = default;
  1504. SlidingWindowTranspose(uint32_t out_h_, uint32_t out_w_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): out_h(out_h_), out_w(out_w_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
  1505. SlidingWindowTranspose(::megdnn::param::SlidingWindowTranspose packed_param_0): out_h(packed_param_0.out_h), out_w(packed_param_0.out_w), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
  1506. ::megdnn::param::SlidingWindowTranspose param() const {
  1507. return {out_h, out_w, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
  1508. }
  1509. };
  1510. class Softmax : public OpDefImplBase<Softmax> {
  1511. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1512. public:
  1513. int32_t axis = -1;
  1514. Softmax() = default;
  1515. Softmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
  1516. Softmax(::megdnn::param::Softmax packed_param_0): axis(packed_param_0.axis) {}
  1517. ::megdnn::param::Softmax param() const {
  1518. return {axis};
  1519. }
  1520. };
  1521. class Split : public OpDefImplBase<Split> {
  1522. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1523. public:
  1524. int32_t axis;
  1525. int32_t nsections;
  1526. Split() = default;
  1527. Split(int32_t axis_, int32_t nsections_, std::string scope_ = {}): axis(axis_), nsections(nsections_) { set_scope(scope_); }
  1528. Split(::megdnn::param::Empty, int32_t axis_, int32_t nsections_): axis(axis_), nsections(nsections_) {}
  1529. ::megdnn::param::Empty param() const {
  1530. return {};
  1531. }
  1532. };
  1533. class Subtensor : public OpDefImplBase<Subtensor> {
  1534. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1535. public:
  1536. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  1537. Subtensor() = default;
  1538. Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
  1539. };
  1540. class TQT : public OpDefImplBase<TQT> {
  1541. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1542. public:
  1543. int32_t qmin = -2147483648;
  1544. int32_t qmax = 2147483647;
  1545. TQT() = default;
  1546. TQT(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
  1547. TQT(::megdnn::param::TQT packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
  1548. ::megdnn::param::TQT param() const {
  1549. return {qmin, qmax};
  1550. }
  1551. };
  1552. class TensorRTRuntime : public OpDefImplBase<TensorRTRuntime> {
  1553. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1554. public:
  1555. std::string buf;
  1556. size_t buf_size;
  1557. TensorRTRuntime() = default;
  1558. TensorRTRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
  1559. };
  1560. class TopK : public OpDefImplBase<TopK> {
  1561. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1562. public:
  1563. using Mode = ::megdnn::param::TopK::Mode;
  1564. Mode mode = ::megdnn::param::TopK::Mode::KTH_ONLY;
  1565. TopK() = default;
  1566. TopK(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
  1567. TopK(::megdnn::param::TopK packed_param_0): mode(packed_param_0.mode) {}
  1568. ::megdnn::param::TopK param() const {
  1569. return {mode};
  1570. }
  1571. };
  1572. class TypeCvt : public OpDefImplBase<TypeCvt> {
  1573. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1574. public:
  1575. ::megdnn::DType dtype;
  1576. TypeCvt() = default;
  1577. TypeCvt(::megdnn::DType dtype_, std::string scope_ = {}): dtype(dtype_) { set_scope(scope_); }
  1578. };
  1579. class UniformRNG : public OpDefImplBase<UniformRNG> {
  1580. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1581. public:
  1582. uint64_t seed = 0;
  1583. ::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
  1584. size_t handle;
  1585. UniformRNG() = default;
  1586. UniformRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
  1587. };
  1588. class WarpAffine : public OpDefImplBase<WarpAffine> {
  1589. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1590. public:
  1591. using InterpolationMode = ::megdnn::param::WarpAffine::InterpolationMode;
  1592. using BorderMode = ::megdnn::param::WarpAffine::BorderMode;
  1593. using Format = ::megdnn::param::WarpAffine::Format;
  1594. InterpolationMode imode = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR;
  1595. BorderMode border_mode = ::megdnn::param::WarpAffine::BorderMode::REPLICATE;
  1596. float border_val = .0f;
  1597. Format format = ::megdnn::param::WarpAffine::Format::NHWC;
  1598. WarpAffine() = default;
  1599. WarpAffine(InterpolationMode imode_, BorderMode border_mode_, float border_val_, Format format_, std::string scope_ = {}): imode(imode_), border_mode(border_mode_), border_val(border_val_), format(format_) { set_scope(scope_); }
  1600. WarpAffine(::megdnn::param::WarpAffine packed_param_0): imode(packed_param_0.imode), border_mode(packed_param_0.border_mode), border_val(packed_param_0.border_val), format(packed_param_0.format) {}
  1601. ::megdnn::param::WarpAffine param() const {
  1602. return {imode, border_mode, border_val, format};
  1603. }
  1604. };
  1605. class WarpPerspective : public OpDefImplBase<WarpPerspective> {
  1606. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  1607. public:
  1608. using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
  1609. using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
  1610. using Format = ::megdnn::param::WarpPerspective::Format;
  1611. InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
  1612. BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
  1613. Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
  1614. float border_val = .0f;
  1615. WarpPerspective() = default;
  1616. WarpPerspective(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
  1617. WarpPerspective(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
  1618. ::megdnn::param::WarpPerspective param() const {
  1619. return {imode, bmode, format, border_val};
  1620. }
  1621. };
// clang-format on