You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opdef.py.inl 116 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944
  1. // clang-format off
  2. py::class_<AdaptivePooling, std::shared_ptr<AdaptivePooling>, OpDef> AdaptivePoolingInst(m, "AdaptivePooling");
  3. py::enum_<AdaptivePooling::Mode>(AdaptivePoolingInst, "Mode")
  4. .value("MAX", AdaptivePooling::Mode::MAX)
  5. .value("AVERAGE", AdaptivePooling::Mode::AVERAGE)
  6. .value("AVERAGE_COUNT_EXCLUDE_PADDING", AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING)
  7. .def(py::init([](const std::string& in) {
  8. auto&& str = normalize_enum(in);
  9. if (str == "MAX") return AdaptivePooling::Mode::MAX;
  10. if (str == "AVERAGE") return AdaptivePooling::Mode::AVERAGE;
  11. if (str == "AVERAGE_COUNT_EXCLUDE_PADDING") return AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  12. throw py::cast_error("invalid enum value " + in);
  13. }));
  14. py::implicitly_convertible<std::string, AdaptivePooling::Mode>();
  15. py::enum_<AdaptivePooling::Format>(AdaptivePoolingInst, "Format")
  16. .value("NCHW", AdaptivePooling::Format::NCHW)
  17. .value("NHWC", AdaptivePooling::Format::NHWC)
  18. .value("NHWCD4", AdaptivePooling::Format::NHWCD4)
  19. .value("NCHW4", AdaptivePooling::Format::NCHW4)
  20. .value("NCHW8", AdaptivePooling::Format::NCHW8)
  21. .value("NCHW32", AdaptivePooling::Format::NCHW32)
  22. .value("NCHW88", AdaptivePooling::Format::NCHW88)
  23. .value("NCHW44", AdaptivePooling::Format::NCHW44)
  24. .value("NCHW44_DOT", AdaptivePooling::Format::NCHW44_DOT)
  25. .value("NCHW4_NCHW32", AdaptivePooling::Format::NCHW4_NCHW32)
  26. .value("NCHW32_NCHW4", AdaptivePooling::Format::NCHW32_NCHW4)
  27. .value("NCHW4_NCHW", AdaptivePooling::Format::NCHW4_NCHW)
  28. .value("NHWC_NCHW", AdaptivePooling::Format::NHWC_NCHW)
  29. .value("NHWC_NCHW4_IC_SMALL", AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL)
  30. .value("NCHW_NCHW4_IC_SMALL", AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL)
  31. .value("CHWN4", AdaptivePooling::Format::CHWN4)
  32. .value("NCHW64", AdaptivePooling::Format::NCHW64)
  33. .value("NCHW4_NHWC", AdaptivePooling::Format::NCHW4_NHWC)
  34. .def(py::init([](const std::string& in) {
  35. auto&& str = normalize_enum(in);
  36. if (str == "NCHW") return AdaptivePooling::Format::NCHW;
  37. if (str == "NHWC") return AdaptivePooling::Format::NHWC;
  38. if (str == "NHWCD4") return AdaptivePooling::Format::NHWCD4;
  39. if (str == "NCHW4") return AdaptivePooling::Format::NCHW4;
  40. if (str == "NCHW8") return AdaptivePooling::Format::NCHW8;
  41. if (str == "NCHW32") return AdaptivePooling::Format::NCHW32;
  42. if (str == "NCHW88") return AdaptivePooling::Format::NCHW88;
  43. if (str == "NCHW44") return AdaptivePooling::Format::NCHW44;
  44. if (str == "NCHW44_DOT") return AdaptivePooling::Format::NCHW44_DOT;
  45. if (str == "NCHW4_NCHW32") return AdaptivePooling::Format::NCHW4_NCHW32;
  46. if (str == "NCHW32_NCHW4") return AdaptivePooling::Format::NCHW32_NCHW4;
  47. if (str == "NCHW4_NCHW") return AdaptivePooling::Format::NCHW4_NCHW;
  48. if (str == "NHWC_NCHW") return AdaptivePooling::Format::NHWC_NCHW;
  49. if (str == "NHWC_NCHW4_IC_SMALL") return AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL;
  50. if (str == "NCHW_NCHW4_IC_SMALL") return AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL;
  51. if (str == "CHWN4") return AdaptivePooling::Format::CHWN4;
  52. if (str == "NCHW64") return AdaptivePooling::Format::NCHW64;
  53. if (str == "NCHW4_NHWC") return AdaptivePooling::Format::NCHW4_NHWC;
  54. throw py::cast_error("invalid enum value " + in);
  55. }));
  56. py::implicitly_convertible<std::string, AdaptivePooling::Format>();
  57. AdaptivePoolingInst
  58. .def(py::init<::megdnn::param::AdaptivePooling::Mode, ::megdnn::param::AdaptivePooling::Format, std::vector<int32_t>, std::string>(), py::arg("mode") = ::megdnn::param::AdaptivePooling::Mode::MAX, py::arg("format") = ::megdnn::param::AdaptivePooling::Format::NCHW, py::arg("shape"), py::arg("scope") = {})
  59. .def(py::init<>())
  60. .def_readwrite("mode", &AdaptivePooling::mode)
  61. .def_readwrite("format", &AdaptivePooling::format)
  62. .def_readwrite("shape", &AdaptivePooling::shape);
  63. py::class_<AddAxis, std::shared_ptr<AddAxis>, OpDef> AddAxisInst(m, "AddAxis");
  64. AddAxisInst
  65. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
  66. .def(py::init<>())
  67. .def_readwrite("axis", &AddAxis::axis);
  68. py::class_<Argmax, std::shared_ptr<Argmax>, OpDef> ArgmaxInst(m, "Argmax");
  69. ArgmaxInst
  70. .def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
  71. .def_readwrite("axis", &Argmax::axis);
  72. py::class_<Argmin, std::shared_ptr<Argmin>, OpDef> ArgminInst(m, "Argmin");
  73. ArgminInst
  74. .def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
  75. .def_readwrite("axis", &Argmin::axis);
  76. py::class_<Argsort, std::shared_ptr<Argsort>, OpDef> ArgsortInst(m, "Argsort");
  77. py::enum_<Argsort::Order>(ArgsortInst, "Order")
  78. .value("ASCENDING", Argsort::Order::ASCENDING)
  79. .value("DESCENDING", Argsort::Order::DESCENDING)
  80. .def(py::init([](const std::string& in) {
  81. auto&& str = normalize_enum(in);
  82. if (str == "ASCENDING") return Argsort::Order::ASCENDING;
  83. if (str == "DESCENDING") return Argsort::Order::DESCENDING;
  84. throw py::cast_error("invalid enum value " + in);
  85. }));
  86. py::implicitly_convertible<std::string, Argsort::Order>();
  87. ArgsortInst
  88. .def(py::init<::megdnn::param::Argsort::Order, std::string>(), py::arg("order") = ::megdnn::param::Argsort::Order::ASCENDING, py::arg("scope") = {})
  89. .def_readwrite("order", &Argsort::order);
  90. py::class_<AssertEqual, std::shared_ptr<AssertEqual>, OpDef> AssertEqualInst(m, "AssertEqual");
  91. AssertEqualInst
  92. .def(py::init<float, bool, std::string>(), py::arg("maxerr") = 0.0001, py::arg("verbose") = false, py::arg("scope") = {})
  93. .def_readwrite("maxerr", &AssertEqual::maxerr)
  94. .def_readwrite("verbose", &AssertEqual::verbose);
  95. py::class_<AtlasRuntime, std::shared_ptr<AtlasRuntime>, OpDef> AtlasRuntimeInst(m, "AtlasRuntime");
  96. AtlasRuntimeInst
  97. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  98. .def(py::init<>())
  99. .def_readwrite("buf", &AtlasRuntime::buf)
  100. .def_readwrite("buf_size", &AtlasRuntime::buf_size);
  101. py::class_<Barrier, std::shared_ptr<Barrier>, OpDef> BarrierInst(m, "Barrier");
  102. BarrierInst
  103. .def(py::init<::mgb::CompNode, uint32_t, std::string>(), py::arg("comp_node"), py::arg("nr_outputs"), py::arg("scope") = {})
  104. .def(py::init<>())
  105. .def_readwrite("comp_node", &Barrier::comp_node)
  106. .def_readwrite("nr_outputs", &Barrier::nr_outputs);
  107. py::class_<BatchConvBias, std::shared_ptr<BatchConvBias>, OpDef> BatchConvBiasInst(m, "BatchConvBias");
  108. py::enum_<BatchConvBias::NonlineMode>(BatchConvBiasInst, "NonlineMode")
  109. .value("IDENTITY", BatchConvBias::NonlineMode::IDENTITY)
  110. .value("RELU", BatchConvBias::NonlineMode::RELU)
  111. .value("SIGMOID", BatchConvBias::NonlineMode::SIGMOID)
  112. .value("H_SWISH", BatchConvBias::NonlineMode::H_SWISH)
  113. .def(py::init([](const std::string& in) {
  114. auto&& str = normalize_enum(in);
  115. if (str == "IDENTITY") return BatchConvBias::NonlineMode::IDENTITY;
  116. if (str == "RELU") return BatchConvBias::NonlineMode::RELU;
  117. if (str == "SIGMOID") return BatchConvBias::NonlineMode::SIGMOID;
  118. if (str == "H_SWISH") return BatchConvBias::NonlineMode::H_SWISH;
  119. throw py::cast_error("invalid enum value " + in);
  120. }));
  121. py::implicitly_convertible<std::string, BatchConvBias::NonlineMode>();
  122. py::enum_<BatchConvBias::Mode>(BatchConvBiasInst, "Mode")
  123. .value("CROSS_CORRELATION", BatchConvBias::Mode::CROSS_CORRELATION)
  124. .value("CONVOLUTION", BatchConvBias::Mode::CONVOLUTION)
  125. .def(py::init([](const std::string& in) {
  126. auto&& str = normalize_enum(in);
  127. if (str == "CROSS_CORRELATION") return BatchConvBias::Mode::CROSS_CORRELATION;
  128. if (str == "CONVOLUTION") return BatchConvBias::Mode::CONVOLUTION;
  129. throw py::cast_error("invalid enum value " + in);
  130. }));
  131. py::implicitly_convertible<std::string, BatchConvBias::Mode>();
  132. py::enum_<BatchConvBias::Sparse>(BatchConvBiasInst, "Sparse")
  133. .value("DENSE", BatchConvBias::Sparse::DENSE)
  134. .value("GROUP", BatchConvBias::Sparse::GROUP)
  135. .def(py::init([](const std::string& in) {
  136. auto&& str = normalize_enum(in);
  137. if (str == "DENSE") return BatchConvBias::Sparse::DENSE;
  138. if (str == "GROUP") return BatchConvBias::Sparse::GROUP;
  139. throw py::cast_error("invalid enum value " + in);
  140. }));
  141. py::implicitly_convertible<std::string, BatchConvBias::Sparse>();
  142. BatchConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  143. py::enum_<BatchConvBias::ComputeMode>(BatchConvBiasInst, "ComputeMode")
  144. .value("DEFAULT", BatchConvBias::ComputeMode::DEFAULT)
  145. .value("FLOAT32", BatchConvBias::ComputeMode::FLOAT32)
  146. .def(py::init([](const std::string& in) {
  147. auto&& str = normalize_enum(in);
  148. if (str == "DEFAULT") return BatchConvBias::ComputeMode::DEFAULT;
  149. if (str == "FLOAT32") return BatchConvBias::ComputeMode::FLOAT32;
  150. throw py::cast_error("invalid enum value " + in);
  151. }));
  152. py::implicitly_convertible<std::string, BatchConvBias::ComputeMode>();
  153. py::enum_<BatchConvBias::Strategy>(BatchConvBiasInst, "Strategy")
  154. .value("HEURISTIC", BatchConvBias::Strategy::HEURISTIC)
  155. .value("PROFILE", BatchConvBias::Strategy::PROFILE)
  156. .value("REPRODUCIBLE", BatchConvBias::Strategy::REPRODUCIBLE)
  157. .value("OPTIMIZED", BatchConvBias::Strategy::OPTIMIZED)
  158. .def("__or__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
  159. return static_cast<BatchConvBias::Strategy>(uint32_t(s0) | uint32_t(s1));
  160. })
  161. .def("__and__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
  162. return static_cast<BatchConvBias::Strategy>(uint32_t(s0) & uint32_t(s1));
  163. })
  164. .def(py::init([](const std::string& in) {
  165. auto&& str = normalize_enum(in);
  166. if (str == "HEURISTIC") return BatchConvBias::Strategy::HEURISTIC;
  167. if (str == "PROFILE") return BatchConvBias::Strategy::PROFILE;
  168. if (str == "REPRODUCIBLE") return BatchConvBias::Strategy::REPRODUCIBLE;
  169. if (str == "OPTIMIZED") return BatchConvBias::Strategy::OPTIMIZED;
  170. throw py::cast_error("invalid enum value " + in);
  171. }));
  172. py::implicitly_convertible<std::string, BatchConvBias::Strategy>();
  173. BatchConvBiasInst
  174. .def(py::init<::megdnn::param::BatchConvBias::NonlineMode, ::megdnn::param::BatchConvBias::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::BatchConvBias::Sparse, ::megdnn::param::BatchConvBias::Format, ::megdnn::param::BatchConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::BatchConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::BatchConvBias::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  175. .def(py::init<>())
  176. .def_readwrite("nonlineMode", &BatchConvBias::nonlineMode)
  177. .def_readwrite("mode", &BatchConvBias::mode)
  178. .def_readwrite("pad_h", &BatchConvBias::pad_h)
  179. .def_readwrite("pad_w", &BatchConvBias::pad_w)
  180. .def_readwrite("stride_h", &BatchConvBias::stride_h)
  181. .def_readwrite("stride_w", &BatchConvBias::stride_w)
  182. .def_readwrite("dilate_h", &BatchConvBias::dilate_h)
  183. .def_readwrite("dilate_w", &BatchConvBias::dilate_w)
  184. .def_readwrite("sparse", &BatchConvBias::sparse)
  185. .def_readwrite("format", &BatchConvBias::format)
  186. .def_readwrite("compute_mode", &BatchConvBias::compute_mode)
  187. .def_readwrite("strategy", &BatchConvBias::strategy)
  188. .def_readwrite("workspace_limit", &BatchConvBias::workspace_limit)
  189. .def_readwrite("dtype", &BatchConvBias::dtype);
  190. py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> BatchNormInst(m, "BatchNorm");
  191. py::enum_<BatchNorm::ParamDim>(BatchNormInst, "ParamDim")
  192. .value("DIM_11HW", BatchNorm::ParamDim::DIM_11HW)
  193. .value("DIM_1CHW", BatchNorm::ParamDim::DIM_1CHW)
  194. .value("DIM_1C11", BatchNorm::ParamDim::DIM_1C11)
  195. .value("DIM_111C", BatchNorm::ParamDim::DIM_111C)
  196. .def(py::init([](const std::string& in) {
  197. auto&& str = normalize_enum(in);
  198. if (str == "DIM_11HW") return BatchNorm::ParamDim::DIM_11HW;
  199. if (str == "DIM_1CHW") return BatchNorm::ParamDim::DIM_1CHW;
  200. if (str == "DIM_1C11") return BatchNorm::ParamDim::DIM_1C11;
  201. if (str == "DIM_111C") return BatchNorm::ParamDim::DIM_111C;
  202. throw py::cast_error("invalid enum value " + in);
  203. }));
  204. py::implicitly_convertible<std::string, BatchNorm::ParamDim>();
  205. py::enum_<BatchNorm::FwdMode>(BatchNormInst, "FwdMode")
  206. .value("TRAINING", BatchNorm::FwdMode::TRAINING)
  207. .value("INFERENCE", BatchNorm::FwdMode::INFERENCE)
  208. .def(py::init([](const std::string& in) {
  209. auto&& str = normalize_enum(in);
  210. if (str == "TRAINING") return BatchNorm::FwdMode::TRAINING;
  211. if (str == "INFERENCE") return BatchNorm::FwdMode::INFERENCE;
  212. throw py::cast_error("invalid enum value " + in);
  213. }));
  214. py::implicitly_convertible<std::string, BatchNorm::FwdMode>();
  215. BatchNormInst
  216. .def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
  217. .def_readwrite("param_dim", &BatchNorm::param_dim)
  218. .def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
  219. .def_readwrite("epsilon", &BatchNorm::epsilon)
  220. .def_readwrite("avg_factor", &BatchNorm::avg_factor)
  221. .def_readwrite("scale", &BatchNorm::scale)
  222. .def_readwrite("bias", &BatchNorm::bias);
  223. py::class_<BatchNormBackward, std::shared_ptr<BatchNormBackward>, OpDef> BatchNormBackwardInst(m, "BatchNormBackward");
  224. BatchNormBackwardInst.attr("ParamDim") = BatchNormInst.attr("ParamDim");
  225. BatchNormBackwardInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  226. BatchNormBackwardInst
  227. .def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
  228. .def_readwrite("param_dim", &BatchNormBackward::param_dim)
  229. .def_readwrite("fwd_mode", &BatchNormBackward::fwd_mode)
  230. .def_readwrite("epsilon", &BatchNormBackward::epsilon)
  231. .def_readwrite("avg_factor", &BatchNormBackward::avg_factor)
  232. .def_readwrite("scale", &BatchNormBackward::scale)
  233. .def_readwrite("bias", &BatchNormBackward::bias);
  234. py::class_<BatchedIncrMeshIndexing, std::shared_ptr<BatchedIncrMeshIndexing>, OpDef> BatchedIncrMeshIndexingInst(m, "BatchedIncrMeshIndexing");
  235. BatchedIncrMeshIndexingInst
  236. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  237. .def(py::init<>())
  238. .def_readwrite("items", &BatchedIncrMeshIndexing::items);
  239. py::class_<BatchedMatrixMul, std::shared_ptr<BatchedMatrixMul>, OpDef> BatchedMatrixMulInst(m, "BatchedMatrixMul");
  240. py::enum_<BatchedMatrixMul::ComputeMode>(BatchedMatrixMulInst, "ComputeMode")
  241. .value("DEFAULT", BatchedMatrixMul::ComputeMode::DEFAULT)
  242. .value("FLOAT32", BatchedMatrixMul::ComputeMode::FLOAT32)
  243. .def(py::init([](const std::string& in) {
  244. auto&& str = normalize_enum(in);
  245. if (str == "DEFAULT") return BatchedMatrixMul::ComputeMode::DEFAULT;
  246. if (str == "FLOAT32") return BatchedMatrixMul::ComputeMode::FLOAT32;
  247. throw py::cast_error("invalid enum value " + in);
  248. }));
  249. py::implicitly_convertible<std::string, BatchedMatrixMul::ComputeMode>();
  250. py::enum_<BatchedMatrixMul::Format>(BatchedMatrixMulInst, "Format")
  251. .value("DEFAULT", BatchedMatrixMul::Format::DEFAULT)
  252. .value("MK4", BatchedMatrixMul::Format::MK4)
  253. .value("MK8", BatchedMatrixMul::Format::MK8)
  254. .value("MK4_DOT", BatchedMatrixMul::Format::MK4_DOT)
  255. .value("N32K4_DOT", BatchedMatrixMul::Format::N32K4_DOT)
  256. .def(py::init([](const std::string& in) {
  257. auto&& str = normalize_enum(in);
  258. if (str == "DEFAULT") return BatchedMatrixMul::Format::DEFAULT;
  259. if (str == "MK4") return BatchedMatrixMul::Format::MK4;
  260. if (str == "MK8") return BatchedMatrixMul::Format::MK8;
  261. if (str == "MK4_DOT") return BatchedMatrixMul::Format::MK4_DOT;
  262. if (str == "N32K4_DOT") return BatchedMatrixMul::Format::N32K4_DOT;
  263. throw py::cast_error("invalid enum value " + in);
  264. }));
  265. py::implicitly_convertible<std::string, BatchedMatrixMul::Format>();
  266. BatchedMatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  267. BatchedMatrixMulInst
  268. .def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
  269. .def(py::init<>())
  270. .def_readwrite("transposeA", &BatchedMatrixMul::transposeA)
  271. .def_readwrite("transposeB", &BatchedMatrixMul::transposeB)
  272. .def_readwrite("compute_mode", &BatchedMatrixMul::compute_mode)
  273. .def_readwrite("format", &BatchedMatrixMul::format)
  274. .def_readwrite("strategy", &BatchedMatrixMul::strategy)
  275. .def_readwrite("workspace_limit", &BatchedMatrixMul::workspace_limit)
  276. .def_readwrite("dimA", &BatchedMatrixMul::dimA)
  277. .def_readwrite("dimB", &BatchedMatrixMul::dimB);
  278. py::class_<BatchedMeshIndexing, std::shared_ptr<BatchedMeshIndexing>, OpDef> BatchedMeshIndexingInst(m, "BatchedMeshIndexing");
  279. BatchedMeshIndexingInst
  280. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  281. .def(py::init<>())
  282. .def_readwrite("items", &BatchedMeshIndexing::items);
  283. py::class_<BatchedSetMeshIndexing, std::shared_ptr<BatchedSetMeshIndexing>, OpDef> BatchedSetMeshIndexingInst(m, "BatchedSetMeshIndexing");
  284. BatchedSetMeshIndexingInst
  285. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  286. .def(py::init<>())
  287. .def_readwrite("items", &BatchedSetMeshIndexing::items);
  288. py::class_<BetaRNG, std::shared_ptr<BetaRNG>, OpDef> BetaRNGInst(m, "BetaRNG");
  289. BetaRNGInst
  290. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  291. .def(py::init<>())
  292. .def_readwrite("seed", &BetaRNG::seed)
  293. .def_readwrite("handle", &BetaRNG::handle);
  294. py::class_<Borrow, std::shared_ptr<Borrow>, OpDef> BorrowInst(m, "Borrow");
  295. BorrowInst
  296. .def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
  297. .def(py::init<>())
  298. .def_readwrite("comp_node", &Borrow::comp_node);
  299. py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef> BroadcastInst(m, "Broadcast");
  300. BroadcastInst
  301. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("shape"), py::arg("scope") = {})
  302. .def(py::init<>())
  303. .def_readwrite("shape", &Broadcast::shape);
  304. py::class_<CambriconRuntime, std::shared_ptr<CambriconRuntime>, OpDef> CambriconRuntimeInst(m, "CambriconRuntime");
  305. CambriconRuntimeInst
  306. .def(py::init<std::string, size_t, std::string, bool, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("symbol"), py::arg("tensor_dim_mutable"), py::arg("scope") = {})
  307. .def(py::init<>())
  308. .def_readwrite("buf", &CambriconRuntime::buf)
  309. .def_readwrite("buf_size", &CambriconRuntime::buf_size)
  310. .def_readwrite("symbol", &CambriconRuntime::symbol)
  311. .def_readwrite("tensor_dim_mutable", &CambriconRuntime::tensor_dim_mutable);
  312. py::class_<CheckNonFinite, std::shared_ptr<CheckNonFinite>, OpDef> CheckNonFiniteInst(m, "CheckNonFinite");
  313. CheckNonFiniteInst
  314. .def(py::init<float, std::string>(), py::arg("scale") = 1.0, py::arg("scope") = {})
  315. .def_readwrite("scale", &CheckNonFinite::scale);
  316. py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef> CollectiveCommInst(m, "CollectiveComm");
  317. py::enum_<CollectiveComm::Mode>(CollectiveCommInst, "Mode")
  318. .value("REDUCE_SUM", CollectiveComm::Mode::REDUCE_SUM)
  319. .value("BROADCAST", CollectiveComm::Mode::BROADCAST)
  320. .value("ALL_GATHER", CollectiveComm::Mode::ALL_GATHER)
  321. .value("REDUCE_SCATTER_SUM", CollectiveComm::Mode::REDUCE_SCATTER_SUM)
  322. .value("ALL_REDUCE_SUM", CollectiveComm::Mode::ALL_REDUCE_SUM)
  323. .value("ALL_REDUCE_MAX", CollectiveComm::Mode::ALL_REDUCE_MAX)
  324. .value("ALL_REDUCE_MIN", CollectiveComm::Mode::ALL_REDUCE_MIN)
  325. .value("ALL_REDUCE_PROD", CollectiveComm::Mode::ALL_REDUCE_PROD)
  326. .value("GATHER", CollectiveComm::Mode::GATHER)
  327. .value("SCATTER", CollectiveComm::Mode::SCATTER)
  328. .value("ALL_TO_ALL", CollectiveComm::Mode::ALL_TO_ALL)
  329. .def(py::init([](const std::string& in) {
  330. auto&& str = normalize_enum(in);
  331. if (str == "REDUCE_SUM") return CollectiveComm::Mode::REDUCE_SUM;
  332. if (str == "BROADCAST") return CollectiveComm::Mode::BROADCAST;
  333. if (str == "ALL_GATHER") return CollectiveComm::Mode::ALL_GATHER;
  334. if (str == "REDUCE_SCATTER_SUM") return CollectiveComm::Mode::REDUCE_SCATTER_SUM;
  335. if (str == "ALL_REDUCE_SUM") return CollectiveComm::Mode::ALL_REDUCE_SUM;
  336. if (str == "ALL_REDUCE_MAX") return CollectiveComm::Mode::ALL_REDUCE_MAX;
  337. if (str == "ALL_REDUCE_MIN") return CollectiveComm::Mode::ALL_REDUCE_MIN;
  338. if (str == "ALL_REDUCE_PROD") return CollectiveComm::Mode::ALL_REDUCE_PROD;
  339. if (str == "GATHER") return CollectiveComm::Mode::GATHER;
  340. if (str == "SCATTER") return CollectiveComm::Mode::SCATTER;
  341. if (str == "ALL_TO_ALL") return CollectiveComm::Mode::ALL_TO_ALL;
  342. throw py::cast_error("invalid enum value " + in);
  343. }));
  344. py::implicitly_convertible<std::string, CollectiveComm::Mode>();
  345. CollectiveCommInst
  346. .def(py::init<::megdnn::param::CollectiveComm::Mode, std::string, uint32_t, uint32_t, bool, bool, std::string, uint32_t, ::megdnn::DType, std::string, std::string, std::string>(), py::arg("mode") = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM, py::arg("key"), py::arg("nr_devices"), py::arg("rank"), py::arg("is_root"), py::arg("local_grad"), py::arg("addr"), py::arg("port"), py::arg("dtype"), py::arg("backend"), py::arg("comp_node"), py::arg("scope") = {})
  347. .def(py::init<>())
  348. .def_readwrite("mode", &CollectiveComm::mode)
  349. .def_readwrite("key", &CollectiveComm::key)
  350. .def_readwrite("nr_devices", &CollectiveComm::nr_devices)
  351. .def_readwrite("rank", &CollectiveComm::rank)
  352. .def_readwrite("is_root", &CollectiveComm::is_root)
  353. .def_readwrite("local_grad", &CollectiveComm::local_grad)
  354. .def_readwrite("addr", &CollectiveComm::addr)
  355. .def_readwrite("port", &CollectiveComm::port)
  356. .def_readwrite("dtype", &CollectiveComm::dtype)
  357. .def_readwrite("backend", &CollectiveComm::backend)
  358. .def_readwrite("comp_node", &CollectiveComm::comp_node);
  359. py::class_<Concat, std::shared_ptr<Concat>, OpDef> ConcatInst(m, "Concat");
  360. ConcatInst
  361. .def(py::init<int32_t, ::mgb::CompNode, std::string>(), py::arg("axis") = 0, py::arg("comp_node"), py::arg("scope") = {})
  362. .def(py::init<>())
  363. .def_readwrite("axis", &Concat::axis)
  364. .def_readwrite("comp_node", &Concat::comp_node);
  365. py::class_<CondTake, std::shared_ptr<CondTake>, OpDef> CondTakeInst(m, "CondTake");
  366. CondTakeInst
  367. .def(py::init<>());
  368. py::class_<ConvBias, std::shared_ptr<ConvBias>, OpDef> ConvBiasInst(m, "ConvBias");
  369. ConvBiasInst.attr("NonlineMode") = BatchConvBiasInst.attr("NonlineMode");
  370. ConvBiasInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  371. ConvBiasInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  372. ConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  373. ConvBiasInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  374. ConvBiasInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  375. ConvBiasInst
  376. .def(py::init<::megdnn::param::ConvBias::NonlineMode, ::megdnn::param::ConvBias::Mode, ::megdnn::param::ConvBias::Sparse, ::megdnn::param::ConvBias::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::ConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::ConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION, py::arg("sparse") = ::megdnn::param::ConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::ConvBias::Format::NCHW, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("compute_mode") = ::megdnn::param::ConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  377. .def(py::init<>())
  378. .def_readwrite("nonlineMode", &ConvBias::nonlineMode)
  379. .def_readwrite("mode", &ConvBias::mode)
  380. .def_readwrite("sparse", &ConvBias::sparse)
  381. .def_readwrite("format", &ConvBias::format)
  382. .def_readwrite("pad_h", &ConvBias::pad_h)
  383. .def_readwrite("pad_w", &ConvBias::pad_w)
  384. .def_readwrite("stride_h", &ConvBias::stride_h)
  385. .def_readwrite("stride_w", &ConvBias::stride_w)
  386. .def_readwrite("dilate_h", &ConvBias::dilate_h)
  387. .def_readwrite("dilate_w", &ConvBias::dilate_w)
  388. .def_readwrite("compute_mode", &ConvBias::compute_mode)
  389. .def_readwrite("strategy", &ConvBias::strategy)
  390. .def_readwrite("workspace_limit", &ConvBias::workspace_limit)
  391. .def_readwrite("dtype", &ConvBias::dtype);
  392. py::class_<Convolution, std::shared_ptr<Convolution>, OpDef> ConvolutionInst(m, "Convolution");
  393. ConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  394. ConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  395. ConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  396. ConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  397. ConvolutionInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  398. ConvolutionInst
  399. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  400. .def_readwrite("mode", &Convolution::mode)
  401. .def_readwrite("pad_h", &Convolution::pad_h)
  402. .def_readwrite("pad_w", &Convolution::pad_w)
  403. .def_readwrite("stride_h", &Convolution::stride_h)
  404. .def_readwrite("stride_w", &Convolution::stride_w)
  405. .def_readwrite("dilate_h", &Convolution::dilate_h)
  406. .def_readwrite("dilate_w", &Convolution::dilate_w)
  407. .def_readwrite("sparse", &Convolution::sparse)
  408. .def_readwrite("format", &Convolution::format)
  409. .def_readwrite("compute_mode", &Convolution::compute_mode)
  410. .def_readwrite("strategy", &Convolution::strategy)
  411. .def_readwrite("workspace_limit", &Convolution::workspace_limit);
  412. py::class_<Convolution3D, std::shared_ptr<Convolution3D>, OpDef> Convolution3DInst(m, "Convolution3D");
  413. py::enum_<Convolution3D::Mode>(Convolution3DInst, "Mode")
  414. .value("CROSS_CORRELATION", Convolution3D::Mode::CROSS_CORRELATION)
  415. .value("CONVOLUTION", Convolution3D::Mode::CONVOLUTION)
  416. .def(py::init([](const std::string& in) {
  417. auto&& str = normalize_enum(in);
  418. if (str == "CROSS_CORRELATION") return Convolution3D::Mode::CROSS_CORRELATION;
  419. if (str == "CONVOLUTION") return Convolution3D::Mode::CONVOLUTION;
  420. throw py::cast_error("invalid enum value " + in);
  421. }));
  422. py::implicitly_convertible<std::string, Convolution3D::Mode>();
  423. py::enum_<Convolution3D::Sparse>(Convolution3DInst, "Sparse")
  424. .value("DENSE", Convolution3D::Sparse::DENSE)
  425. .value("GROUP", Convolution3D::Sparse::GROUP)
  426. .def(py::init([](const std::string& in) {
  427. auto&& str = normalize_enum(in);
  428. if (str == "DENSE") return Convolution3D::Sparse::DENSE;
  429. if (str == "GROUP") return Convolution3D::Sparse::GROUP;
  430. throw py::cast_error("invalid enum value " + in);
  431. }));
  432. py::implicitly_convertible<std::string, Convolution3D::Sparse>();
  433. py::enum_<Convolution3D::DataType>(Convolution3DInst, "DataType")
  434. .value("FLOAT", Convolution3D::DataType::FLOAT)
  435. .value("FLOAT_IO16xC32", Convolution3D::DataType::FLOAT_IO16xC32)
  436. .def(py::init([](const std::string& in) {
  437. auto&& str = normalize_enum(in);
  438. if (str == "FLOAT") return Convolution3D::DataType::FLOAT;
  439. if (str == "FLOAT_IO16xC32") return Convolution3D::DataType::FLOAT_IO16xC32;
  440. throw py::cast_error("invalid enum value " + in);
  441. }));
  442. py::implicitly_convertible<std::string, Convolution3D::DataType>();
  443. py::enum_<Convolution3D::Format>(Convolution3DInst, "Format")
  444. .value("NCDHW", Convolution3D::Format::NCDHW)
  445. .value("NDHWC", Convolution3D::Format::NDHWC)
  446. .def(py::init([](const std::string& in) {
  447. auto&& str = normalize_enum(in);
  448. if (str == "NCDHW") return Convolution3D::Format::NCDHW;
  449. if (str == "NDHWC") return Convolution3D::Format::NDHWC;
  450. throw py::cast_error("invalid enum value " + in);
  451. }));
  452. py::implicitly_convertible<std::string, Convolution3D::Format>();
  453. Convolution3DInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  454. Convolution3DInst
  455. .def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  456. .def_readwrite("mode", &Convolution3D::mode)
  457. .def_readwrite("pad_d", &Convolution3D::pad_d)
  458. .def_readwrite("pad_h", &Convolution3D::pad_h)
  459. .def_readwrite("pad_w", &Convolution3D::pad_w)
  460. .def_readwrite("stride_d", &Convolution3D::stride_d)
  461. .def_readwrite("stride_h", &Convolution3D::stride_h)
  462. .def_readwrite("stride_w", &Convolution3D::stride_w)
  463. .def_readwrite("dilate_d", &Convolution3D::dilate_d)
  464. .def_readwrite("dilate_h", &Convolution3D::dilate_h)
  465. .def_readwrite("dilate_w", &Convolution3D::dilate_w)
  466. .def_readwrite("sparse", &Convolution3D::sparse)
  467. .def_readwrite("data_type", &Convolution3D::data_type)
  468. .def_readwrite("format", &Convolution3D::format)
  469. .def_readwrite("strategy", &Convolution3D::strategy)
  470. .def_readwrite("workspace_limit", &Convolution3D::workspace_limit);
  471. py::class_<Convolution3DBackwardData, std::shared_ptr<Convolution3DBackwardData>, OpDef> Convolution3DBackwardDataInst(m, "Convolution3DBackwardData");
  472. Convolution3DBackwardDataInst.attr("Mode") = Convolution3DInst.attr("Mode");
  473. Convolution3DBackwardDataInst.attr("Sparse") = Convolution3DInst.attr("Sparse");
  474. Convolution3DBackwardDataInst.attr("DataType") = Convolution3DInst.attr("DataType");
  475. Convolution3DBackwardDataInst.attr("Format") = Convolution3DInst.attr("Format");
  476. Convolution3DBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  477. Convolution3DBackwardDataInst
  478. .def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  479. .def_readwrite("mode", &Convolution3DBackwardData::mode)
  480. .def_readwrite("pad_d", &Convolution3DBackwardData::pad_d)
  481. .def_readwrite("pad_h", &Convolution3DBackwardData::pad_h)
  482. .def_readwrite("pad_w", &Convolution3DBackwardData::pad_w)
  483. .def_readwrite("stride_d", &Convolution3DBackwardData::stride_d)
  484. .def_readwrite("stride_h", &Convolution3DBackwardData::stride_h)
  485. .def_readwrite("stride_w", &Convolution3DBackwardData::stride_w)
  486. .def_readwrite("dilate_d", &Convolution3DBackwardData::dilate_d)
  487. .def_readwrite("dilate_h", &Convolution3DBackwardData::dilate_h)
  488. .def_readwrite("dilate_w", &Convolution3DBackwardData::dilate_w)
  489. .def_readwrite("sparse", &Convolution3DBackwardData::sparse)
  490. .def_readwrite("data_type", &Convolution3DBackwardData::data_type)
  491. .def_readwrite("format", &Convolution3DBackwardData::format)
  492. .def_readwrite("strategy", &Convolution3DBackwardData::strategy)
  493. .def_readwrite("workspace_limit", &Convolution3DBackwardData::workspace_limit);
  494. py::class_<ConvolutionBackwardData, std::shared_ptr<ConvolutionBackwardData>, OpDef> ConvolutionBackwardDataInst(m, "ConvolutionBackwardData");
  495. ConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  496. ConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  497. ConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  498. ConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  499. ConvolutionBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  500. ConvolutionBackwardDataInst
  501. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  502. .def(py::init<>())
  503. .def_readwrite("mode", &ConvolutionBackwardData::mode)
  504. .def_readwrite("pad_h", &ConvolutionBackwardData::pad_h)
  505. .def_readwrite("pad_w", &ConvolutionBackwardData::pad_w)
  506. .def_readwrite("stride_h", &ConvolutionBackwardData::stride_h)
  507. .def_readwrite("stride_w", &ConvolutionBackwardData::stride_w)
  508. .def_readwrite("dilate_h", &ConvolutionBackwardData::dilate_h)
  509. .def_readwrite("dilate_w", &ConvolutionBackwardData::dilate_w)
  510. .def_readwrite("sparse", &ConvolutionBackwardData::sparse)
  511. .def_readwrite("format", &ConvolutionBackwardData::format)
  512. .def_readwrite("compute_mode", &ConvolutionBackwardData::compute_mode)
  513. .def_readwrite("strategy", &ConvolutionBackwardData::strategy)
  514. .def_readwrite("workspace_limit", &ConvolutionBackwardData::workspace_limit)
  515. .def_readwrite("dtype", &ConvolutionBackwardData::dtype);
  516. py::class_<Copy, std::shared_ptr<Copy>, OpDef> CopyInst(m, "Copy");
  517. CopyInst
  518. .def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
  519. .def(py::init<>())
  520. .def_readwrite("comp_node", &Copy::comp_node);
  521. py::class_<Correlation, std::shared_ptr<Correlation>, OpDef> CorrelationInst(m, "Correlation");
  522. py::enum_<Correlation::Format>(CorrelationInst, "Format")
  523. .value("NCHW", Correlation::Format::NCHW)
  524. .value("NHWC", Correlation::Format::NHWC)
  525. .value("NHWCD4", Correlation::Format::NHWCD4)
  526. .value("NCHW4", Correlation::Format::NCHW4)
  527. .value("NCHW8", Correlation::Format::NCHW8)
  528. .value("NCHW32", Correlation::Format::NCHW32)
  529. .value("NCHW88", Correlation::Format::NCHW88)
  530. .value("NCHW44", Correlation::Format::NCHW44)
  531. .value("NCHW44_DOT", Correlation::Format::NCHW44_DOT)
  532. .value("NCHW_WINOGRAD", Correlation::Format::NCHW_WINOGRAD)
  533. .value("NCHW88_WINOGRAD", Correlation::Format::NCHW88_WINOGRAD)
  534. .value("NCHW44_WINOGRAD", Correlation::Format::NCHW44_WINOGRAD)
  535. .value("NCHW4_NCHW32", Correlation::Format::NCHW4_NCHW32)
  536. .value("NCHW32_NCHW4", Correlation::Format::NCHW32_NCHW4)
  537. .value("NCHW4_NCHW", Correlation::Format::NCHW4_NCHW)
  538. .value("NHWC_NCHW", Correlation::Format::NHWC_NCHW)
  539. .value("NHWC_NCHW4_IC_SMALL", Correlation::Format::NHWC_NCHW4_IC_SMALL)
  540. .value("NCHW_NCHW4_IC_SMALL", Correlation::Format::NCHW_NCHW4_IC_SMALL)
  541. .value("CHWN4", Correlation::Format::CHWN4)
  542. .value("NCHW4_NHWC", Correlation::Format::NCHW4_NHWC)
  543. .def(py::init([](const std::string& in) {
  544. auto&& str = normalize_enum(in);
  545. if (str == "NCHW") return Correlation::Format::NCHW;
  546. if (str == "NHWC") return Correlation::Format::NHWC;
  547. if (str == "NHWCD4") return Correlation::Format::NHWCD4;
  548. if (str == "NCHW4") return Correlation::Format::NCHW4;
  549. if (str == "NCHW8") return Correlation::Format::NCHW8;
  550. if (str == "NCHW32") return Correlation::Format::NCHW32;
  551. if (str == "NCHW88") return Correlation::Format::NCHW88;
  552. if (str == "NCHW44") return Correlation::Format::NCHW44;
  553. if (str == "NCHW44_DOT") return Correlation::Format::NCHW44_DOT;
  554. if (str == "NCHW_WINOGRAD") return Correlation::Format::NCHW_WINOGRAD;
  555. if (str == "NCHW88_WINOGRAD") return Correlation::Format::NCHW88_WINOGRAD;
  556. if (str == "NCHW44_WINOGRAD") return Correlation::Format::NCHW44_WINOGRAD;
  557. if (str == "NCHW4_NCHW32") return Correlation::Format::NCHW4_NCHW32;
  558. if (str == "NCHW32_NCHW4") return Correlation::Format::NCHW32_NCHW4;
  559. if (str == "NCHW4_NCHW") return Correlation::Format::NCHW4_NCHW;
  560. if (str == "NHWC_NCHW") return Correlation::Format::NHWC_NCHW;
  561. if (str == "NHWC_NCHW4_IC_SMALL") return Correlation::Format::NHWC_NCHW4_IC_SMALL;
  562. if (str == "NCHW_NCHW4_IC_SMALL") return Correlation::Format::NCHW_NCHW4_IC_SMALL;
  563. if (str == "CHWN4") return Correlation::Format::CHWN4;
  564. if (str == "NCHW4_NHWC") return Correlation::Format::NCHW4_NHWC;
  565. throw py::cast_error("invalid enum value " + in);
  566. }));
  567. py::implicitly_convertible<std::string, Correlation::Format>();
  568. CorrelationInst
  569. .def(py::init<::megdnn::param::Correlation::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, std::string>(), py::arg("format") = ::megdnn::param::Correlation::Format::NCHW, py::arg("kernel_size") = 1, py::arg("max_displacement") = 1, py::arg("stride1") = 1, py::arg("stride2") = 1, py::arg("pad_size") = 0, py::arg("is_multiply") = true, py::arg("scope") = {})
  570. .def_readwrite("format", &Correlation::format)
  571. .def_readwrite("kernel_size", &Correlation::kernel_size)
  572. .def_readwrite("max_displacement", &Correlation::max_displacement)
  573. .def_readwrite("stride1", &Correlation::stride1)
  574. .def_readwrite("stride2", &Correlation::stride2)
  575. .def_readwrite("pad_size", &Correlation::pad_size)
  576. .def_readwrite("is_multiply", &Correlation::is_multiply);
  577. py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum");
  578. CumsumInst
  579. .def(py::init<int32_t, bool, bool, std::string>(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {})
  580. .def_readwrite("axis", &Cumsum::axis)
  581. .def_readwrite("exclusive", &Cumsum::exclusive)
  582. .def_readwrite("reverse", &Cumsum::reverse);
  583. py::class_<CvtColor, std::shared_ptr<CvtColor>, OpDef> CvtColorInst(m, "CvtColor");
  584. py::enum_<CvtColor::Mode>(CvtColorInst, "Mode")
  585. .value("RGB2GRAY", CvtColor::Mode::RGB2GRAY)
  586. .value("RGB2YUV", CvtColor::Mode::RGB2YUV)
  587. .value("YUV2RGB", CvtColor::Mode::YUV2RGB)
  588. .value("GRAY2RGB", CvtColor::Mode::GRAY2RGB)
  589. .value("RGBA2RGB", CvtColor::Mode::RGBA2RGB)
  590. .value("RGBA2BGR", CvtColor::Mode::RGBA2BGR)
  591. .value("RGBA2GRAY", CvtColor::Mode::RGBA2GRAY)
  592. .value("RGB2BGR", CvtColor::Mode::RGB2BGR)
  593. .value("BGR2GRAY", CvtColor::Mode::BGR2GRAY)
  594. .value("BGR2RGB", CvtColor::Mode::BGR2RGB)
  595. .value("YUV2GRAY_NV21", CvtColor::Mode::YUV2GRAY_NV21)
  596. .value("YUV2RGB_NV21", CvtColor::Mode::YUV2RGB_NV21)
  597. .value("YUV2BGR_NV21", CvtColor::Mode::YUV2BGR_NV21)
  598. .value("YUV2GRAY_NV12", CvtColor::Mode::YUV2GRAY_NV12)
  599. .value("YUV2RGB_NV12", CvtColor::Mode::YUV2RGB_NV12)
  600. .value("YUV2BGR_NV12", CvtColor::Mode::YUV2BGR_NV12)
  601. .value("YUV2GRAY_YV12", CvtColor::Mode::YUV2GRAY_YV12)
  602. .value("YUV2RGB_YV12", CvtColor::Mode::YUV2RGB_YV12)
  603. .value("YUV2BGR_YV12", CvtColor::Mode::YUV2BGR_YV12)
  604. .value("YUV2GRAY_YU12", CvtColor::Mode::YUV2GRAY_YU12)
  605. .value("YUV2RGB_YU12", CvtColor::Mode::YUV2RGB_YU12)
  606. .value("YUV2BGR_YU12", CvtColor::Mode::YUV2BGR_YU12)
  607. .value("YCrCb2RGB", CvtColor::Mode::YCrCb2RGB)
  608. .value("YCrCb2BGR", CvtColor::Mode::YCrCb2BGR)
  609. .value("BT601_YUV2RGB_NV21", CvtColor::Mode::BT601_YUV2RGB_NV21)
  610. .value("BT601_YUV2BGR_NV21", CvtColor::Mode::BT601_YUV2BGR_NV21)
  611. .value("BT601_YUV2RGB_NV12", CvtColor::Mode::BT601_YUV2RGB_NV12)
  612. .value("BT601_YUV2BGR_NV12", CvtColor::Mode::BT601_YUV2BGR_NV12)
  613. .value("BT601_YUV2RGB_YV12", CvtColor::Mode::BT601_YUV2RGB_YV12)
  614. .value("BT601_YUV2BGR_YV12", CvtColor::Mode::BT601_YUV2BGR_YV12)
  615. .value("BT601_YUV2RGB_YU12", CvtColor::Mode::BT601_YUV2RGB_YU12)
  616. .value("BT601_YUV2BGR_YU12", CvtColor::Mode::BT601_YUV2BGR_YU12)
  617. .def(py::init([](const std::string& in) {
  618. auto&& str = normalize_enum(in);
  619. if (str == "RGB2GRAY") return CvtColor::Mode::RGB2GRAY;
  620. if (str == "RGB2YUV") return CvtColor::Mode::RGB2YUV;
  621. if (str == "YUV2RGB") return CvtColor::Mode::YUV2RGB;
  622. if (str == "GRAY2RGB") return CvtColor::Mode::GRAY2RGB;
  623. if (str == "RGBA2RGB") return CvtColor::Mode::RGBA2RGB;
  624. if (str == "RGBA2BGR") return CvtColor::Mode::RGBA2BGR;
  625. if (str == "RGBA2GRAY") return CvtColor::Mode::RGBA2GRAY;
  626. if (str == "RGB2BGR") return CvtColor::Mode::RGB2BGR;
  627. if (str == "BGR2GRAY") return CvtColor::Mode::BGR2GRAY;
  628. if (str == "BGR2RGB") return CvtColor::Mode::BGR2RGB;
  629. if (str == "YUV2GRAY_NV21") return CvtColor::Mode::YUV2GRAY_NV21;
  630. if (str == "YUV2RGB_NV21") return CvtColor::Mode::YUV2RGB_NV21;
  631. if (str == "YUV2BGR_NV21") return CvtColor::Mode::YUV2BGR_NV21;
  632. if (str == "YUV2GRAY_NV12") return CvtColor::Mode::YUV2GRAY_NV12;
  633. if (str == "YUV2RGB_NV12") return CvtColor::Mode::YUV2RGB_NV12;
  634. if (str == "YUV2BGR_NV12") return CvtColor::Mode::YUV2BGR_NV12;
  635. if (str == "YUV2GRAY_YV12") return CvtColor::Mode::YUV2GRAY_YV12;
  636. if (str == "YUV2RGB_YV12") return CvtColor::Mode::YUV2RGB_YV12;
  637. if (str == "YUV2BGR_YV12") return CvtColor::Mode::YUV2BGR_YV12;
  638. if (str == "YUV2GRAY_YU12") return CvtColor::Mode::YUV2GRAY_YU12;
  639. if (str == "YUV2RGB_YU12") return CvtColor::Mode::YUV2RGB_YU12;
  640. if (str == "YUV2BGR_YU12") return CvtColor::Mode::YUV2BGR_YU12;
  641. if (str == "YCrCb2RGB") return CvtColor::Mode::YCrCb2RGB;
  642. if (str == "YCrCb2BGR") return CvtColor::Mode::YCrCb2BGR;
  643. if (str == "BT601_YUV2RGB_NV21") return CvtColor::Mode::BT601_YUV2RGB_NV21;
  644. if (str == "BT601_YUV2BGR_NV21") return CvtColor::Mode::BT601_YUV2BGR_NV21;
  645. if (str == "BT601_YUV2RGB_NV12") return CvtColor::Mode::BT601_YUV2RGB_NV12;
  646. if (str == "BT601_YUV2BGR_NV12") return CvtColor::Mode::BT601_YUV2BGR_NV12;
  647. if (str == "BT601_YUV2RGB_YV12") return CvtColor::Mode::BT601_YUV2RGB_YV12;
  648. if (str == "BT601_YUV2BGR_YV12") return CvtColor::Mode::BT601_YUV2BGR_YV12;
  649. if (str == "BT601_YUV2RGB_YU12") return CvtColor::Mode::BT601_YUV2RGB_YU12;
  650. if (str == "BT601_YUV2BGR_YU12") return CvtColor::Mode::BT601_YUV2BGR_YU12;
  651. throw py::cast_error("invalid enum value " + in);
  652. }));
  653. py::implicitly_convertible<std::string, CvtColor::Mode>();
  654. CvtColorInst
  655. .def(py::init<::megdnn::param::CvtColor::Mode, std::string>(), py::arg("mode") = ::megdnn::param::CvtColor::Mode::RGB2GRAY, py::arg("scope") = {})
  656. .def_readwrite("mode", &CvtColor::mode);
  657. py::class_<DeformableConv, std::shared_ptr<DeformableConv>, OpDef> DeformableConvInst(m, "DeformableConv");
  658. DeformableConvInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  659. DeformableConvInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  660. DeformableConvInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  661. DeformableConvInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  662. DeformableConvInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  663. DeformableConvInst
  664. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  665. .def_readwrite("mode", &DeformableConv::mode)
  666. .def_readwrite("pad_h", &DeformableConv::pad_h)
  667. .def_readwrite("pad_w", &DeformableConv::pad_w)
  668. .def_readwrite("stride_h", &DeformableConv::stride_h)
  669. .def_readwrite("stride_w", &DeformableConv::stride_w)
  670. .def_readwrite("dilate_h", &DeformableConv::dilate_h)
  671. .def_readwrite("dilate_w", &DeformableConv::dilate_w)
  672. .def_readwrite("sparse", &DeformableConv::sparse)
  673. .def_readwrite("format", &DeformableConv::format)
  674. .def_readwrite("compute_mode", &DeformableConv::compute_mode)
  675. .def_readwrite("strategy", &DeformableConv::strategy)
  676. .def_readwrite("workspace_limit", &DeformableConv::workspace_limit);
  677. py::class_<DeformablePSROIPooling, std::shared_ptr<DeformablePSROIPooling>, OpDef> DeformablePSROIPoolingInst(m, "DeformablePSROIPooling");
  678. DeformablePSROIPoolingInst
  679. .def(py::init<bool, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("no_trans") = true, py::arg("spatial_scale") = 1, py::arg("trans_std") = 1, py::arg("pooled_h") = 1, py::arg("pooled_w") = 1, py::arg("part_size") = 1, py::arg("sample_per_part") = 1, py::arg("scope") = {})
  680. .def_readwrite("no_trans", &DeformablePSROIPooling::no_trans)
  681. .def_readwrite("spatial_scale", &DeformablePSROIPooling::spatial_scale)
  682. .def_readwrite("trans_std", &DeformablePSROIPooling::trans_std)
  683. .def_readwrite("pooled_h", &DeformablePSROIPooling::pooled_h)
  684. .def_readwrite("pooled_w", &DeformablePSROIPooling::pooled_w)
  685. .def_readwrite("part_size", &DeformablePSROIPooling::part_size)
  686. .def_readwrite("sample_per_part", &DeformablePSROIPooling::sample_per_part);
  687. py::class_<Diag, std::shared_ptr<Diag>, OpDef> DiagInst(m, "Diag");
  688. DiagInst
  689. .def(py::init<int32_t, std::string>(), py::arg("k") = 0, py::arg("scope") = {})
  690. .def_readwrite("k", &Diag::k);
  691. py::class_<Dimshuffle, std::shared_ptr<Dimshuffle>, OpDef> DimshuffleInst(m, "Dimshuffle");
  692. DimshuffleInst
  693. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("pattern"), py::arg("scope") = {})
  694. .def(py::init<>())
  695. .def_readwrite("pattern", &Dimshuffle::pattern);
  696. py::class_<Dot, std::shared_ptr<Dot>, OpDef> DotInst(m, "Dot");
  697. DotInst
  698. .def(py::init<>());
  699. py::class_<Dropout, std::shared_ptr<Dropout>, OpDef> DropoutInst(m, "Dropout");
  700. DropoutInst
  701. .def(py::init<float, uint64_t, size_t, std::string>(), py::arg("drop_prob") = 0, py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  702. .def(py::init<>())
  703. .def_readwrite("drop_prob", &Dropout::drop_prob)
  704. .def_readwrite("seed", &Dropout::seed)
  705. .def_readwrite("handle", &Dropout::handle);
  706. py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> ElemwiseInst(m, "Elemwise");
  707. py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
  708. .value("RELU", Elemwise::Mode::RELU)
  709. .value("ABS", Elemwise::Mode::ABS)
  710. .value("ACOS", Elemwise::Mode::ACOS)
  711. .value("ASIN", Elemwise::Mode::ASIN)
  712. .value("CEIL", Elemwise::Mode::CEIL)
  713. .value("COS", Elemwise::Mode::COS)
  714. .value("EXP", Elemwise::Mode::EXP)
  715. .value("EXPM1", Elemwise::Mode::EXPM1)
  716. .value("FLOOR", Elemwise::Mode::FLOOR)
  717. .value("LOG", Elemwise::Mode::LOG)
  718. .value("LOG1P", Elemwise::Mode::LOG1P)
  719. .value("NEGATE", Elemwise::Mode::NEGATE)
  720. .value("SIGMOID", Elemwise::Mode::SIGMOID)
  721. .value("SIN", Elemwise::Mode::SIN)
  722. .value("TANH", Elemwise::Mode::TANH)
  723. .value("ABS_GRAD", Elemwise::Mode::ABS_GRAD)
  724. .value("ADD", Elemwise::Mode::ADD)
  725. .value("FLOOR_DIV", Elemwise::Mode::FLOOR_DIV)
  726. .value("MAX", Elemwise::Mode::MAX)
  727. .value("MIN", Elemwise::Mode::MIN)
  728. .value("MOD", Elemwise::Mode::MOD)
  729. .value("MUL", Elemwise::Mode::MUL)
  730. .value("POW", Elemwise::Mode::POW)
  731. .value("SIGMOID_GRAD", Elemwise::Mode::SIGMOID_GRAD)
  732. .value("SUB", Elemwise::Mode::SUB)
  733. .value("SWITCH_GT0", Elemwise::Mode::SWITCH_GT0)
  734. .value("TANH_GRAD", Elemwise::Mode::TANH_GRAD)
  735. .value("TRUE_DIV", Elemwise::Mode::TRUE_DIV)
  736. .value("LOG_SUM_EXP", Elemwise::Mode::LOG_SUM_EXP)
  737. .value("LT", Elemwise::Mode::LT)
  738. .value("LEQ", Elemwise::Mode::LEQ)
  739. .value("EQ", Elemwise::Mode::EQ)
  740. .value("SHL", Elemwise::Mode::SHL)
  741. .value("SHR", Elemwise::Mode::SHR)
  742. .value("COND_LEQ_MOV", Elemwise::Mode::COND_LEQ_MOV)
  743. .value("FUSE_MUL_ADD3", Elemwise::Mode::FUSE_MUL_ADD3)
  744. .value("FUSE_MUL_ADD4", Elemwise::Mode::FUSE_MUL_ADD4)
  745. .value("FUSE_ADD_RELU", Elemwise::Mode::FUSE_ADD_RELU)
  746. .value("FUSE_ADD_SIGMOID", Elemwise::Mode::FUSE_ADD_SIGMOID)
  747. .value("FUSE_ADD_TANH", Elemwise::Mode::FUSE_ADD_TANH)
  748. .value("FAST_TANH", Elemwise::Mode::FAST_TANH)
  749. .value("FAST_TANH_GRAD", Elemwise::Mode::FAST_TANH_GRAD)
  750. .value("ROUND", Elemwise::Mode::ROUND)
  751. .value("RMULH", Elemwise::Mode::RMULH)
  752. .value("ATAN2", Elemwise::Mode::ATAN2)
  753. .value("ERF", Elemwise::Mode::ERF)
  754. .value("ERFINV", Elemwise::Mode::ERFINV)
  755. .value("ERFC", Elemwise::Mode::ERFC)
  756. .value("ERFCINV", Elemwise::Mode::ERFCINV)
  757. .value("H_SWISH", Elemwise::Mode::H_SWISH)
  758. .value("H_SWISH_GRAD", Elemwise::Mode::H_SWISH_GRAD)
  759. .value("FUSE_ADD_H_SWISH", Elemwise::Mode::FUSE_ADD_H_SWISH)
  760. .value("NOT", Elemwise::Mode::NOT)
  761. .value("AND", Elemwise::Mode::AND)
  762. .value("OR", Elemwise::Mode::OR)
  763. .value("XOR", Elemwise::Mode::XOR)
  764. .value("SILU", Elemwise::Mode::SILU)
  765. .value("SILU_GRAD", Elemwise::Mode::SILU_GRAD)
  766. .value("GELU", Elemwise::Mode::GELU)
  767. .value("GELU_GRAD", Elemwise::Mode::GELU_GRAD)
  768. .value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV)
  769. .value("NEQ", Elemwise::Mode::NEQ)
  770. .value("ISNAN", Elemwise::Mode::ISNAN)
  771. .value("ISINF", Elemwise::Mode::ISINF)
  772. .def(py::init([](const std::string& in) {
  773. auto&& str = normalize_enum(in);
  774. if (str == "RELU") return Elemwise::Mode::RELU;
  775. if (str == "ABS") return Elemwise::Mode::ABS;
  776. if (str == "ACOS") return Elemwise::Mode::ACOS;
  777. if (str == "ASIN") return Elemwise::Mode::ASIN;
  778. if (str == "CEIL") return Elemwise::Mode::CEIL;
  779. if (str == "COS") return Elemwise::Mode::COS;
  780. if (str == "EXP") return Elemwise::Mode::EXP;
  781. if (str == "EXPM1") return Elemwise::Mode::EXPM1;
  782. if (str == "FLOOR") return Elemwise::Mode::FLOOR;
  783. if (str == "LOG") return Elemwise::Mode::LOG;
  784. if (str == "LOG1P") return Elemwise::Mode::LOG1P;
  785. if (str == "NEGATE") return Elemwise::Mode::NEGATE;
  786. if (str == "SIGMOID") return Elemwise::Mode::SIGMOID;
  787. if (str == "SIN") return Elemwise::Mode::SIN;
  788. if (str == "TANH") return Elemwise::Mode::TANH;
  789. if (str == "ABS_GRAD") return Elemwise::Mode::ABS_GRAD;
  790. if (str == "ADD") return Elemwise::Mode::ADD;
  791. if (str == "FLOOR_DIV") return Elemwise::Mode::FLOOR_DIV;
  792. if (str == "MAX") return Elemwise::Mode::MAX;
  793. if (str == "MIN") return Elemwise::Mode::MIN;
  794. if (str == "MOD") return Elemwise::Mode::MOD;
  795. if (str == "MUL") return Elemwise::Mode::MUL;
  796. if (str == "POW") return Elemwise::Mode::POW;
  797. if (str == "SIGMOID_GRAD") return Elemwise::Mode::SIGMOID_GRAD;
  798. if (str == "SUB") return Elemwise::Mode::SUB;
  799. if (str == "SWITCH_GT0") return Elemwise::Mode::SWITCH_GT0;
  800. if (str == "TANH_GRAD") return Elemwise::Mode::TANH_GRAD;
  801. if (str == "TRUE_DIV") return Elemwise::Mode::TRUE_DIV;
  802. if (str == "LOG_SUM_EXP") return Elemwise::Mode::LOG_SUM_EXP;
  803. if (str == "LT") return Elemwise::Mode::LT;
  804. if (str == "LEQ") return Elemwise::Mode::LEQ;
  805. if (str == "EQ") return Elemwise::Mode::EQ;
  806. if (str == "SHL") return Elemwise::Mode::SHL;
  807. if (str == "SHR") return Elemwise::Mode::SHR;
  808. if (str == "COND_LEQ_MOV") return Elemwise::Mode::COND_LEQ_MOV;
  809. if (str == "FUSE_MUL_ADD3") return Elemwise::Mode::FUSE_MUL_ADD3;
  810. if (str == "FUSE_MUL_ADD4") return Elemwise::Mode::FUSE_MUL_ADD4;
  811. if (str == "FUSE_ADD_RELU") return Elemwise::Mode::FUSE_ADD_RELU;
  812. if (str == "FUSE_ADD_SIGMOID") return Elemwise::Mode::FUSE_ADD_SIGMOID;
  813. if (str == "FUSE_ADD_TANH") return Elemwise::Mode::FUSE_ADD_TANH;
  814. if (str == "FAST_TANH") return Elemwise::Mode::FAST_TANH;
  815. if (str == "FAST_TANH_GRAD") return Elemwise::Mode::FAST_TANH_GRAD;
  816. if (str == "ROUND") return Elemwise::Mode::ROUND;
  817. if (str == "RMULH") return Elemwise::Mode::RMULH;
  818. if (str == "ATAN2") return Elemwise::Mode::ATAN2;
  819. if (str == "ERF") return Elemwise::Mode::ERF;
  820. if (str == "ERFINV") return Elemwise::Mode::ERFINV;
  821. if (str == "ERFC") return Elemwise::Mode::ERFC;
  822. if (str == "ERFCINV") return Elemwise::Mode::ERFCINV;
  823. if (str == "H_SWISH") return Elemwise::Mode::H_SWISH;
  824. if (str == "H_SWISH_GRAD") return Elemwise::Mode::H_SWISH_GRAD;
  825. if (str == "FUSE_ADD_H_SWISH") return Elemwise::Mode::FUSE_ADD_H_SWISH;
  826. if (str == "NOT") return Elemwise::Mode::NOT;
  827. if (str == "AND") return Elemwise::Mode::AND;
  828. if (str == "OR") return Elemwise::Mode::OR;
  829. if (str == "XOR") return Elemwise::Mode::XOR;
  830. if (str == "SILU") return Elemwise::Mode::SILU;
  831. if (str == "SILU_GRAD") return Elemwise::Mode::SILU_GRAD;
  832. if (str == "GELU") return Elemwise::Mode::GELU;
  833. if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD;
  834. if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV;
  835. if (str == "NEQ") return Elemwise::Mode::NEQ;
  836. if (str == "ISNAN") return Elemwise::Mode::ISNAN;
  837. if (str == "ISINF") return Elemwise::Mode::ISINF;
  838. throw py::cast_error("invalid enum value " + in);
  839. }));
  840. py::implicitly_convertible<std::string, Elemwise::Mode>();
  841. ElemwiseInst
  842. .def(py::init<::megdnn::param::Elemwise::Mode, std::string>(), py::arg("mode") = ::megdnn::param::Elemwise::Mode::RELU, py::arg("scope") = {})
  843. .def_readwrite("mode", &Elemwise::mode);
  844. py::class_<ElemwiseMultiType, std::shared_ptr<ElemwiseMultiType>, OpDef> ElemwiseMultiTypeInst(m, "ElemwiseMultiType");
  845. py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
  846. .value("FUSE_MUL_ADD3_INT16x32x32x32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32)
  847. .value("FUSE_MUL_ADD3_IXxF32xF32xI8", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8)
  848. .value("ROUND_SHR_SATURATE_IXxI8xI8", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8)
  849. .value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8)
  850. .value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8)
  851. .value("ROUND_SHR_SATURATE_IXxI8xI16", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16)
  852. .value("QADD", ElemwiseMultiType::Mode::QADD)
  853. .value("QFUSE_ADD_RELU", ElemwiseMultiType::Mode::QFUSE_ADD_RELU)
  854. .value("QMUL", ElemwiseMultiType::Mode::QMUL)
  855. .value("QMIN", ElemwiseMultiType::Mode::QMIN)
  856. .value("QMAX", ElemwiseMultiType::Mode::QMAX)
  857. .value("QSUB", ElemwiseMultiType::Mode::QSUB)
  858. .value("QTRUE_DIV", ElemwiseMultiType::Mode::QTRUE_DIV)
  859. .value("QFUSE_ADD_SIGMOID", ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID)
  860. .value("QFUSE_ADD_TANH", ElemwiseMultiType::Mode::QFUSE_ADD_TANH)
  861. .value("QRELU", ElemwiseMultiType::Mode::QRELU)
  862. .value("QABS", ElemwiseMultiType::Mode::QABS)
  863. .value("QSIGMOID", ElemwiseMultiType::Mode::QSIGMOID)
  864. .value("QEXP", ElemwiseMultiType::Mode::QEXP)
  865. .value("QTANH", ElemwiseMultiType::Mode::QTANH)
  866. .value("QFUSE_MUL_ADD3", ElemwiseMultiType::Mode::QFUSE_MUL_ADD3)
  867. .value("QFAST_TANH", ElemwiseMultiType::Mode::QFAST_TANH)
  868. .value("QNEGATE", ElemwiseMultiType::Mode::QNEGATE)
  869. .value("QACOS", ElemwiseMultiType::Mode::QACOS)
  870. .value("QASIN", ElemwiseMultiType::Mode::QASIN)
  871. .value("QCEIL", ElemwiseMultiType::Mode::QCEIL)
  872. .value("QCOS", ElemwiseMultiType::Mode::QCOS)
  873. .value("QEXPM1", ElemwiseMultiType::Mode::QEXPM1)
  874. .value("QFLOOR", ElemwiseMultiType::Mode::QFLOOR)
  875. .value("QLOG", ElemwiseMultiType::Mode::QLOG)
  876. .value("QLOG1P", ElemwiseMultiType::Mode::QLOG1P)
  877. .value("QSIN", ElemwiseMultiType::Mode::QSIN)
  878. .value("QROUND", ElemwiseMultiType::Mode::QROUND)
  879. .value("QERF", ElemwiseMultiType::Mode::QERF)
  880. .value("QERFINV", ElemwiseMultiType::Mode::QERFINV)
  881. .value("QERFC", ElemwiseMultiType::Mode::QERFC)
  882. .value("QERFCINV", ElemwiseMultiType::Mode::QERFCINV)
  883. .value("QABS_GRAD", ElemwiseMultiType::Mode::QABS_GRAD)
  884. .value("QFLOOR_DIV", ElemwiseMultiType::Mode::QFLOOR_DIV)
  885. .value("QMOD", ElemwiseMultiType::Mode::QMOD)
  886. .value("QSIGMOID_GRAD", ElemwiseMultiType::Mode::QSIGMOID_GRAD)
  887. .value("QSWITCH_GT0", ElemwiseMultiType::Mode::QSWITCH_GT0)
  888. .value("QTANH_GRAD", ElemwiseMultiType::Mode::QTANH_GRAD)
  889. .value("QLT", ElemwiseMultiType::Mode::QLT)
  890. .value("QLEQ", ElemwiseMultiType::Mode::QLEQ)
  891. .value("QEQ", ElemwiseMultiType::Mode::QEQ)
  892. .value("QPOW", ElemwiseMultiType::Mode::QPOW)
  893. .value("QLOG_SUM_EXP", ElemwiseMultiType::Mode::QLOG_SUM_EXP)
  894. .value("QFAST_TANH_GRAD", ElemwiseMultiType::Mode::QFAST_TANH_GRAD)
  895. .value("QATAN2", ElemwiseMultiType::Mode::QATAN2)
  896. .value("QCOND_LEQ_MOV", ElemwiseMultiType::Mode::QCOND_LEQ_MOV)
  897. .value("QH_SWISH", ElemwiseMultiType::Mode::QH_SWISH)
  898. .value("QFUSE_ADD_H_SWISH", ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH)
  899. .value("QH_SWISH_GRAD", ElemwiseMultiType::Mode::QH_SWISH_GRAD)
  900. .value("FUSE_MUL_ADD3_INT16xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32)
  901. .value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32)
  902. .value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32)
  903. .value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV)
  904. .value("EQ", ElemwiseMultiType::Mode::EQ)
  905. .value("NEQ", ElemwiseMultiType::Mode::NEQ)
  906. .value("LT", ElemwiseMultiType::Mode::LT)
  907. .value("LEQ", ElemwiseMultiType::Mode::LEQ)
  908. .value("ISNAN", ElemwiseMultiType::Mode::ISNAN)
  909. .value("ISINF", ElemwiseMultiType::Mode::ISINF)
  910. .def(py::init([](const std::string& in) {
  911. auto&& str = normalize_enum(in);
  912. if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
  913. if (str == "FUSE_MUL_ADD3_IXxF32xF32xI8") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8;
  914. if (str == "ROUND_SHR_SATURATE_IXxI8xI8") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8;
  915. if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8;
  916. if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8;
  917. if (str == "ROUND_SHR_SATURATE_IXxI8xI16") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16;
  918. if (str == "QADD") return ElemwiseMultiType::Mode::QADD;
  919. if (str == "QFUSE_ADD_RELU") return ElemwiseMultiType::Mode::QFUSE_ADD_RELU;
  920. if (str == "QMUL") return ElemwiseMultiType::Mode::QMUL;
  921. if (str == "QMIN") return ElemwiseMultiType::Mode::QMIN;
  922. if (str == "QMAX") return ElemwiseMultiType::Mode::QMAX;
  923. if (str == "QSUB") return ElemwiseMultiType::Mode::QSUB;
  924. if (str == "QTRUE_DIV") return ElemwiseMultiType::Mode::QTRUE_DIV;
  925. if (str == "QFUSE_ADD_SIGMOID") return ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID;
  926. if (str == "QFUSE_ADD_TANH") return ElemwiseMultiType::Mode::QFUSE_ADD_TANH;
  927. if (str == "QRELU") return ElemwiseMultiType::Mode::QRELU;
  928. if (str == "QABS") return ElemwiseMultiType::Mode::QABS;
  929. if (str == "QSIGMOID") return ElemwiseMultiType::Mode::QSIGMOID;
  930. if (str == "QEXP") return ElemwiseMultiType::Mode::QEXP;
  931. if (str == "QTANH") return ElemwiseMultiType::Mode::QTANH;
  932. if (str == "QFUSE_MUL_ADD3") return ElemwiseMultiType::Mode::QFUSE_MUL_ADD3;
  933. if (str == "QFAST_TANH") return ElemwiseMultiType::Mode::QFAST_TANH;
  934. if (str == "QNEGATE") return ElemwiseMultiType::Mode::QNEGATE;
  935. if (str == "QACOS") return ElemwiseMultiType::Mode::QACOS;
  936. if (str == "QASIN") return ElemwiseMultiType::Mode::QASIN;
  937. if (str == "QCEIL") return ElemwiseMultiType::Mode::QCEIL;
  938. if (str == "QCOS") return ElemwiseMultiType::Mode::QCOS;
  939. if (str == "QEXPM1") return ElemwiseMultiType::Mode::QEXPM1;
  940. if (str == "QFLOOR") return ElemwiseMultiType::Mode::QFLOOR;
  941. if (str == "QLOG") return ElemwiseMultiType::Mode::QLOG;
  942. if (str == "QLOG1P") return ElemwiseMultiType::Mode::QLOG1P;
  943. if (str == "QSIN") return ElemwiseMultiType::Mode::QSIN;
  944. if (str == "QROUND") return ElemwiseMultiType::Mode::QROUND;
  945. if (str == "QERF") return ElemwiseMultiType::Mode::QERF;
  946. if (str == "QERFINV") return ElemwiseMultiType::Mode::QERFINV;
  947. if (str == "QERFC") return ElemwiseMultiType::Mode::QERFC;
  948. if (str == "QERFCINV") return ElemwiseMultiType::Mode::QERFCINV;
  949. if (str == "QABS_GRAD") return ElemwiseMultiType::Mode::QABS_GRAD;
  950. if (str == "QFLOOR_DIV") return ElemwiseMultiType::Mode::QFLOOR_DIV;
  951. if (str == "QMOD") return ElemwiseMultiType::Mode::QMOD;
  952. if (str == "QSIGMOID_GRAD") return ElemwiseMultiType::Mode::QSIGMOID_GRAD;
  953. if (str == "QSWITCH_GT0") return ElemwiseMultiType::Mode::QSWITCH_GT0;
  954. if (str == "QTANH_GRAD") return ElemwiseMultiType::Mode::QTANH_GRAD;
  955. if (str == "QLT") return ElemwiseMultiType::Mode::QLT;
  956. if (str == "QLEQ") return ElemwiseMultiType::Mode::QLEQ;
  957. if (str == "QEQ") return ElemwiseMultiType::Mode::QEQ;
  958. if (str == "QPOW") return ElemwiseMultiType::Mode::QPOW;
  959. if (str == "QLOG_SUM_EXP") return ElemwiseMultiType::Mode::QLOG_SUM_EXP;
  960. if (str == "QFAST_TANH_GRAD") return ElemwiseMultiType::Mode::QFAST_TANH_GRAD;
  961. if (str == "QATAN2") return ElemwiseMultiType::Mode::QATAN2;
  962. if (str == "QCOND_LEQ_MOV") return ElemwiseMultiType::Mode::QCOND_LEQ_MOV;
  963. if (str == "QH_SWISH") return ElemwiseMultiType::Mode::QH_SWISH;
  964. if (str == "QFUSE_ADD_H_SWISH") return ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH;
  965. if (str == "QH_SWISH_GRAD") return ElemwiseMultiType::Mode::QH_SWISH_GRAD;
  966. if (str == "FUSE_MUL_ADD3_INT16xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32;
  967. if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32;
  968. if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32;
  969. if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV;
  970. if (str == "EQ") return ElemwiseMultiType::Mode::EQ;
  971. if (str == "NEQ") return ElemwiseMultiType::Mode::NEQ;
  972. if (str == "LT") return ElemwiseMultiType::Mode::LT;
  973. if (str == "LEQ") return ElemwiseMultiType::Mode::LEQ;
  974. if (str == "ISNAN") return ElemwiseMultiType::Mode::ISNAN;
  975. if (str == "ISINF") return ElemwiseMultiType::Mode::ISINF;
  976. throw py::cast_error("invalid enum value " + in);
  977. }));
  978. py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>();
  979. ElemwiseMultiTypeInst
  980. .def(py::init<::megdnn::param::ElemwiseMultiType::Mode, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32, py::arg("dtype"), py::arg("scope") = {})
  981. .def(py::init<>())
  982. .def_readwrite("mode", &ElemwiseMultiType::mode)
  983. .def_readwrite("dtype", &ElemwiseMultiType::dtype);
  984. py::class_<ExternOpr, std::shared_ptr<ExternOpr>, OpDef> ExternOprInst(m, "ExternOpr");
  985. ExternOprInst
  986. .def(py::init<std::vector<std::vector<size_t>>, std::string, std::string, size_t, std::vector<::megdnn::DType>, std::string>(), py::arg("output_shapes"), py::arg("name"), py::arg("data"), py::arg("data_len"), py::arg("output_dtypes"), py::arg("scope") = {})
  987. .def(py::init<>())
  988. .def_readwrite("output_shapes", &ExternOpr::output_shapes)
  989. .def_readwrite("name", &ExternOpr::name)
  990. .def_readwrite("data", &ExternOpr::data)
  991. .def_readwrite("data_len", &ExternOpr::data_len)
  992. .def_readwrite("output_dtypes", &ExternOpr::output_dtypes);
  993. py::class_<Eye, std::shared_ptr<Eye>, OpDef> EyeInst(m, "Eye");
  994. EyeInst
  995. .def(py::init<int32_t, ::megdnn::DType, ::mgb::CompNode, std::string>(), py::arg("k") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("comp_node"), py::arg("scope") = {})
  996. .def(py::init<>())
  997. .def_readwrite("k", &Eye::k)
  998. .def_readwrite("dtype", &Eye::dtype)
  999. .def_readwrite("comp_node", &Eye::comp_node);
  1000. py::class_<FakeQuant, std::shared_ptr<FakeQuant>, OpDef> FakeQuantInst(m, "FakeQuant");
  1001. FakeQuantInst
  1002. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1003. .def_readwrite("qmin", &FakeQuant::qmin)
  1004. .def_readwrite("qmax", &FakeQuant::qmax);
  1005. py::class_<FastpathCopy, std::shared_ptr<FastpathCopy>, OpDef> FastpathCopyInst(m, "FastpathCopy");
  1006. FastpathCopyInst
  1007. .def(py::init<>());
  1008. py::class_<GammaRNG, std::shared_ptr<GammaRNG>, OpDef> GammaRNGInst(m, "GammaRNG");
  1009. GammaRNGInst
  1010. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1011. .def(py::init<>())
  1012. .def_readwrite("seed", &GammaRNG::seed)
  1013. .def_readwrite("handle", &GammaRNG::handle);
  1014. py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef> GaussianRNGInst(m, "GaussianRNG");
  1015. GaussianRNGInst
  1016. .def(py::init<uint64_t, float, float, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("mean") = 0, py::arg("std") = 1, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
  1017. .def(py::init<>())
  1018. .def_readwrite("seed", &GaussianRNG::seed)
  1019. .def_readwrite("mean", &GaussianRNG::mean)
  1020. .def_readwrite("std", &GaussianRNG::std)
  1021. .def_readwrite("dtype", &GaussianRNG::dtype)
  1022. .def_readwrite("handle", &GaussianRNG::handle);
  1023. py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef> GetVarShapeInst(m, "GetVarShape");
  1024. GetVarShapeInst
  1025. .def(py::init<int32_t, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("scope") = {})
  1026. .def_readwrite("axis", &GetVarShape::axis);
  1027. py::class_<GroupLocal, std::shared_ptr<GroupLocal>, OpDef> GroupLocalInst(m, "GroupLocal");
  1028. GroupLocalInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  1029. GroupLocalInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  1030. GroupLocalInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1031. GroupLocalInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  1032. GroupLocalInst
  1033. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
  1034. .def_readwrite("mode", &GroupLocal::mode)
  1035. .def_readwrite("pad_h", &GroupLocal::pad_h)
  1036. .def_readwrite("pad_w", &GroupLocal::pad_w)
  1037. .def_readwrite("stride_h", &GroupLocal::stride_h)
  1038. .def_readwrite("stride_w", &GroupLocal::stride_w)
  1039. .def_readwrite("dilate_h", &GroupLocal::dilate_h)
  1040. .def_readwrite("dilate_w", &GroupLocal::dilate_w)
  1041. .def_readwrite("sparse", &GroupLocal::sparse)
  1042. .def_readwrite("format", &GroupLocal::format)
  1043. .def_readwrite("compute_mode", &GroupLocal::compute_mode);
  1044. py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity");
  1045. IdentityInst
  1046. .def(py::init<>());
  1047. py::class_<Images2Neibs, std::shared_ptr<Images2Neibs>, OpDef> Images2NeibsInst(m, "Images2Neibs");
  1048. Images2NeibsInst
  1049. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
  1050. .def_readwrite("pad_h", &Images2Neibs::pad_h)
  1051. .def_readwrite("pad_w", &Images2Neibs::pad_w)
  1052. .def_readwrite("stride_h", &Images2Neibs::stride_h)
  1053. .def_readwrite("stride_w", &Images2Neibs::stride_w)
  1054. .def_readwrite("dilate_h", &Images2Neibs::dilate_h)
  1055. .def_readwrite("dilate_w", &Images2Neibs::dilate_w)
  1056. .def_readwrite("window_h", &Images2Neibs::window_h)
  1057. .def_readwrite("window_w", &Images2Neibs::window_w);
  1058. py::class_<IncrMeshIndexing, std::shared_ptr<IncrMeshIndexing>, OpDef> IncrMeshIndexingInst(m, "IncrMeshIndexing");
  1059. IncrMeshIndexingInst
  1060. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1061. .def(py::init<>())
  1062. .def_readwrite("items", &IncrMeshIndexing::items);
  1063. py::class_<IncrSubtensor, std::shared_ptr<IncrSubtensor>, OpDef> IncrSubtensorInst(m, "IncrSubtensor");
  1064. IncrSubtensorInst
  1065. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1066. .def(py::init<>())
  1067. .def_readwrite("items", &IncrSubtensor::items);
  1068. py::class_<IndexingIncrMultiAxisVec, std::shared_ptr<IndexingIncrMultiAxisVec>, OpDef> IndexingIncrMultiAxisVecInst(m, "IndexingIncrMultiAxisVec");
  1069. IndexingIncrMultiAxisVecInst
  1070. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1071. .def(py::init<>())
  1072. .def_readwrite("items", &IndexingIncrMultiAxisVec::items);
  1073. py::class_<IndexingMultiAxisVec, std::shared_ptr<IndexingMultiAxisVec>, OpDef> IndexingMultiAxisVecInst(m, "IndexingMultiAxisVec");
  1074. IndexingMultiAxisVecInst
  1075. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1076. .def(py::init<>())
  1077. .def_readwrite("items", &IndexingMultiAxisVec::items);
  1078. py::class_<IndexingOneHot, std::shared_ptr<IndexingOneHot>, OpDef> IndexingOneHotInst(m, "IndexingOneHot");
  1079. IndexingOneHotInst
  1080. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
  1081. .def(py::init<>())
  1082. .def_readwrite("axis", &IndexingOneHot::axis)
  1083. .def_readwrite("ndim", &IndexingOneHot::ndim);
  1084. py::class_<IndexingSetMultiAxisVec, std::shared_ptr<IndexingSetMultiAxisVec>, OpDef> IndexingSetMultiAxisVecInst(m, "IndexingSetMultiAxisVec");
  1085. IndexingSetMultiAxisVecInst
  1086. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1087. .def(py::init<>())
  1088. .def_readwrite("items", &IndexingSetMultiAxisVec::items);
  1089. py::class_<IndexingSetOneHot, std::shared_ptr<IndexingSetOneHot>, OpDef> IndexingSetOneHotInst(m, "IndexingSetOneHot");
  1090. IndexingSetOneHotInst
  1091. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
  1092. .def(py::init<>())
  1093. .def_readwrite("axis", &IndexingSetOneHot::axis)
  1094. .def_readwrite("ndim", &IndexingSetOneHot::ndim);
  1095. py::class_<InplaceAdd, std::shared_ptr<InplaceAdd>, OpDef> InplaceAddInst(m, "InplaceAdd");
  1096. InplaceAddInst
  1097. .def(py::init<>());
  1098. py::class_<LAMBUpdate, std::shared_ptr<LAMBUpdate>, OpDef> LAMBUpdateInst(m, "LAMBUpdate");
  1099. LAMBUpdateInst
  1100. .def(py::init<float, float, float, float, float, float, bool, bool, std::string>(), py::arg("beta_1") = 1.f, py::arg("beta_2") = 1.f, py::arg("step") = 1.f, py::arg("lr") = 1.f, py::arg("weight_decay") = 1.f, py::arg("eps") = 1.f, py::arg("bias_correction") = true, py::arg("always_adapt") = false, py::arg("scope") = {})
  1101. .def_readwrite("beta_1", &LAMBUpdate::beta_1)
  1102. .def_readwrite("beta_2", &LAMBUpdate::beta_2)
  1103. .def_readwrite("step", &LAMBUpdate::step)
  1104. .def_readwrite("lr", &LAMBUpdate::lr)
  1105. .def_readwrite("weight_decay", &LAMBUpdate::weight_decay)
  1106. .def_readwrite("eps", &LAMBUpdate::eps)
  1107. .def_readwrite("bias_correction", &LAMBUpdate::bias_correction)
  1108. .def_readwrite("always_adapt", &LAMBUpdate::always_adapt);
  1109. py::class_<LRN, std::shared_ptr<LRN>, OpDef> LRNInst(m, "LRN");
  1110. LRNInst
  1111. .def(py::init<uint32_t, float, float, float, std::string>(), py::arg("n") = 5, py::arg("k") = 2.f, py::arg("alpha") = 1e-4f, py::arg("beta") = 0.75f, py::arg("scope") = {})
  1112. .def_readwrite("n", &LRN::n)
  1113. .def_readwrite("k", &LRN::k)
  1114. .def_readwrite("alpha", &LRN::alpha)
  1115. .def_readwrite("beta", &LRN::beta);
  1116. py::class_<LSQ, std::shared_ptr<LSQ>, OpDef> LSQInst(m, "LSQ");
  1117. LSQInst
  1118. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1119. .def_readwrite("qmin", &LSQ::qmin)
  1120. .def_readwrite("qmax", &LSQ::qmax);
  1121. py::class_<LSTM, std::shared_ptr<LSTM>, OpDef> LSTMInst(m, "LSTM");
  1122. LSTMInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  1123. LSTMInst
  1124. .def(py::init<uint32_t, bool, bool, uint32_t, uint32_t, float, ::megdnn::param::LSTM::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("proj_size") = 0, py::arg("dropout") = 0.f, py::arg("fwd_mode") = ::megdnn::param::LSTM::FwdMode::TRAINING, py::arg("scope") = {})
  1125. .def_readwrite("num_layers", &LSTM::num_layers)
  1126. .def_readwrite("bidirectional", &LSTM::bidirectional)
  1127. .def_readwrite("bias", &LSTM::bias)
  1128. .def_readwrite("hidden_size", &LSTM::hidden_size)
  1129. .def_readwrite("proj_size", &LSTM::proj_size)
  1130. .def_readwrite("dropout", &LSTM::dropout)
  1131. .def_readwrite("fwd_mode", &LSTM::fwd_mode);
  1132. py::class_<LSTMCell, std::shared_ptr<LSTMCell>, OpDef> LSTMCellInst(m, "LSTMCell");
  1133. LSTMCellInst
  1134. .def(py::init<>());
  1135. py::class_<LayerNorm, std::shared_ptr<LayerNorm>, OpDef> LayerNormInst(m, "LayerNorm");
  1136. LayerNormInst
  1137. .def(py::init<bool, float, uint64_t, uint64_t, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("normalized_dim") = 1, py::arg("normalized_size") = 1, py::arg("scope") = {})
  1138. .def_readwrite("affine", &LayerNorm::affine)
  1139. .def_readwrite("eps", &LayerNorm::eps)
  1140. .def_readwrite("normalized_dim", &LayerNorm::normalized_dim)
  1141. .def_readwrite("normalized_size", &LayerNorm::normalized_size);
  1142. py::class_<Linspace, std::shared_ptr<Linspace>, OpDef> LinspaceInst(m, "Linspace");
  1143. LinspaceInst
  1144. .def(py::init<bool, ::mgb::CompNode, std::string>(), py::arg("endpoint") = true, py::arg("comp_node"), py::arg("scope") = {})
  1145. .def(py::init<>())
  1146. .def_readwrite("endpoint", &Linspace::endpoint)
  1147. .def_readwrite("comp_node", &Linspace::comp_node);
  1148. py::class_<MagicMindRuntime, std::shared_ptr<MagicMindRuntime>, OpDef> MagicMindRuntimeInst(m, "MagicMindRuntime");
  1149. MagicMindRuntimeInst
  1150. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  1151. .def(py::init<>())
  1152. .def_readwrite("buf", &MagicMindRuntime::buf)
  1153. .def_readwrite("buf_size", &MagicMindRuntime::buf_size);
  1154. py::class_<MatrixInverse, std::shared_ptr<MatrixInverse>, OpDef> MatrixInverseInst(m, "MatrixInverse");
  1155. MatrixInverseInst
  1156. .def(py::init<>());
  1157. py::class_<MatrixMul, std::shared_ptr<MatrixMul>, OpDef> MatrixMulInst(m, "MatrixMul");
  1158. MatrixMulInst.attr("ComputeMode") = BatchedMatrixMulInst.attr("ComputeMode");
  1159. MatrixMulInst.attr("Format") = BatchedMatrixMulInst.attr("Format");
  1160. MatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  1161. MatrixMulInst
  1162. .def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
  1163. .def(py::init<>())
  1164. .def_readwrite("transposeA", &MatrixMul::transposeA)
  1165. .def_readwrite("transposeB", &MatrixMul::transposeB)
  1166. .def_readwrite("compute_mode", &MatrixMul::compute_mode)
  1167. .def_readwrite("format", &MatrixMul::format)
  1168. .def_readwrite("strategy", &MatrixMul::strategy)
  1169. .def_readwrite("workspace_limit", &MatrixMul::workspace_limit)
  1170. .def_readwrite("dimA", &MatrixMul::dimA)
  1171. .def_readwrite("dimB", &MatrixMul::dimB);
  1172. py::class_<MeshGrid, std::shared_ptr<MeshGrid>, OpDef> MeshGridInst(m, "MeshGrid");
  1173. MeshGridInst
  1174. .def(py::init<std::string, std::string>(), py::arg("indexing"), py::arg("scope") = {})
  1175. .def(py::init<>())
  1176. .def_readwrite("indexing", &MeshGrid::indexing);
  1177. py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
  1178. MeshIndexingInst
  1179. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1180. .def(py::init<>())
  1181. .def_readwrite("items", &MeshIndexing::items);
  1182. py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep");
  1183. NMSKeepInst
  1184. .def(py::init<float, uint32_t, std::string>(), py::arg("iou_thresh"), py::arg("max_output"), py::arg("scope") = {})
  1185. .def(py::init<>())
  1186. .def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
  1187. .def_readwrite("max_output", &NMSKeep::max_output);
  1188. py::class_<NvOf, std::shared_ptr<NvOf>, OpDef> NvOfInst(m, "NvOf");
  1189. NvOfInst
  1190. .def(py::init<uint32_t, std::string>(), py::arg("precision") = 1, py::arg("scope") = {})
  1191. .def_readwrite("precision", &NvOf::precision);
  1192. py::class_<Padding, std::shared_ptr<Padding>, OpDef> PaddingInst(m, "Padding");
  1193. py::enum_<Padding::PaddingMode>(PaddingInst, "PaddingMode")
  1194. .value("REPLICATE", Padding::PaddingMode::REPLICATE)
  1195. .value("REFLECT", Padding::PaddingMode::REFLECT)
  1196. .value("CONSTANT", Padding::PaddingMode::CONSTANT)
  1197. .def(py::init([](const std::string& in) {
  1198. auto&& str = normalize_enum(in);
  1199. if (str == "REPLICATE") return Padding::PaddingMode::REPLICATE;
  1200. if (str == "REFLECT") return Padding::PaddingMode::REFLECT;
  1201. if (str == "CONSTANT") return Padding::PaddingMode::CONSTANT;
  1202. throw py::cast_error("invalid enum value " + in);
  1203. }));
  1204. py::implicitly_convertible<std::string, Padding::PaddingMode>();
  1205. PaddingInst
  1206. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float, ::megdnn::param::Padding::PaddingMode, std::string>(), py::arg("front_offset_dim0") = 0, py::arg("front_offset_dim1") = 0, py::arg("front_offset_dim2") = 0, py::arg("front_offset_dim3") = 0, py::arg("front_offset_dim4") = 0, py::arg("front_offset_dim5") = 0, py::arg("front_offset_dim6") = 0, py::arg("back_offset_dim0") = 0, py::arg("back_offset_dim1") = 0, py::arg("back_offset_dim2") = 0, py::arg("back_offset_dim3") = 0, py::arg("back_offset_dim4") = 0, py::arg("back_offset_dim5") = 0, py::arg("back_offset_dim6") = 0, py::arg("padding_val") = 0, py::arg("padding_mode") = ::megdnn::param::Padding::PaddingMode::CONSTANT, py::arg("scope") = {})
  1207. .def_readwrite("front_offset_dim0", &Padding::front_offset_dim0)
  1208. .def_readwrite("front_offset_dim1", &Padding::front_offset_dim1)
  1209. .def_readwrite("front_offset_dim2", &Padding::front_offset_dim2)
  1210. .def_readwrite("front_offset_dim3", &Padding::front_offset_dim3)
  1211. .def_readwrite("front_offset_dim4", &Padding::front_offset_dim4)
  1212. .def_readwrite("front_offset_dim5", &Padding::front_offset_dim5)
  1213. .def_readwrite("front_offset_dim6", &Padding::front_offset_dim6)
  1214. .def_readwrite("back_offset_dim0", &Padding::back_offset_dim0)
  1215. .def_readwrite("back_offset_dim1", &Padding::back_offset_dim1)
  1216. .def_readwrite("back_offset_dim2", &Padding::back_offset_dim2)
  1217. .def_readwrite("back_offset_dim3", &Padding::back_offset_dim3)
  1218. .def_readwrite("back_offset_dim4", &Padding::back_offset_dim4)
  1219. .def_readwrite("back_offset_dim5", &Padding::back_offset_dim5)
  1220. .def_readwrite("back_offset_dim6", &Padding::back_offset_dim6)
  1221. .def_readwrite("padding_val", &Padding::padding_val)
  1222. .def_readwrite("padding_mode", &Padding::padding_mode);
  1223. py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef> ParamPackConcatInst(m, "ParamPackConcat");
  1224. ParamPackConcatInst
  1225. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("offsets"), py::arg("scope") = {})
  1226. .def(py::init<>())
  1227. .def_readwrite("offsets", &ParamPackConcat::offsets);
  1228. py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef> ParamPackSplitInst(m, "ParamPackSplit");
  1229. ParamPackSplitInst
  1230. .def(py::init<std::vector<int32_t>, std::vector<std::vector<size_t>>, std::string>(), py::arg("offsets"), py::arg("shapes"), py::arg("scope") = {})
  1231. .def(py::init<>())
  1232. .def_readwrite("offsets", &ParamPackSplit::offsets)
  1233. .def_readwrite("shapes", &ParamPackSplit::shapes);
  1234. py::class_<PermutationRNG, std::shared_ptr<PermutationRNG>, OpDef> PermutationRNGInst(m, "PermutationRNG");
  1235. PermutationRNGInst
  1236. .def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32), py::arg("handle"), py::arg("scope") = {})
  1237. .def(py::init<>())
  1238. .def_readwrite("seed", &PermutationRNG::seed)
  1239. .def_readwrite("dtype", &PermutationRNG::dtype)
  1240. .def_readwrite("handle", &PermutationRNG::handle);
  1241. py::class_<PixelShuffle, std::shared_ptr<PixelShuffle>, OpDef> PixelShuffleInst(m, "PixelShuffle");
  1242. PixelShuffleInst
  1243. .def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
  1244. .def(py::init<>())
  1245. .def_readwrite("factor", &PixelShuffle::factor);
  1246. py::class_<PixelShuffleBackward, std::shared_ptr<PixelShuffleBackward>, OpDef> PixelShuffleBackwardInst(m, "PixelShuffleBackward");
  1247. PixelShuffleBackwardInst
  1248. .def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
  1249. .def(py::init<>())
  1250. .def_readwrite("factor", &PixelShuffleBackward::factor);
  1251. py::class_<PoissonRNG, std::shared_ptr<PoissonRNG>, OpDef> PoissonRNGInst(m, "PoissonRNG");
  1252. PoissonRNGInst
  1253. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1254. .def(py::init<>())
  1255. .def_readwrite("seed", &PoissonRNG::seed)
  1256. .def_readwrite("handle", &PoissonRNG::handle);
  1257. py::class_<Pooling, std::shared_ptr<Pooling>, OpDef> PoolingInst(m, "Pooling");
  1258. PoolingInst.attr("Mode") = AdaptivePoolingInst.attr("Mode");
  1259. PoolingInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1260. PoolingInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  1261. PoolingInst
  1262. .def(py::init<::megdnn::param::Pooling::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Pooling::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Pooling::Mode::MAX, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 2, py::arg("stride_w") = 2, py::arg("window_h") = 2, py::arg("window_w") = 2, py::arg("format") = ::megdnn::param::Pooling::Format::NCHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  1263. .def_readwrite("mode", &Pooling::mode)
  1264. .def_readwrite("pad_h", &Pooling::pad_h)
  1265. .def_readwrite("pad_w", &Pooling::pad_w)
  1266. .def_readwrite("stride_h", &Pooling::stride_h)
  1267. .def_readwrite("stride_w", &Pooling::stride_w)
  1268. .def_readwrite("window_h", &Pooling::window_h)
  1269. .def_readwrite("window_w", &Pooling::window_w)
  1270. .def_readwrite("format", &Pooling::format)
  1271. .def_readwrite("strategy", &Pooling::strategy)
  1272. .def_readwrite("workspace_limit", &Pooling::workspace_limit);
  1273. py::class_<RNN, std::shared_ptr<RNN>, OpDef> RNNInst(m, "RNN");
  1274. py::enum_<RNN::NonlineMode>(RNNInst, "NonlineMode")
  1275. .value("IDENTITY", RNN::NonlineMode::IDENTITY)
  1276. .value("RELU", RNN::NonlineMode::RELU)
  1277. .value("TANH", RNN::NonlineMode::TANH)
  1278. .def(py::init([](const std::string& in) {
  1279. auto&& str = normalize_enum(in);
  1280. if (str == "IDENTITY") return RNN::NonlineMode::IDENTITY;
  1281. if (str == "RELU") return RNN::NonlineMode::RELU;
  1282. if (str == "TANH") return RNN::NonlineMode::TANH;
  1283. throw py::cast_error("invalid enum value " + in);
  1284. }));
  1285. py::implicitly_convertible<std::string, RNN::NonlineMode>();
  1286. RNNInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  1287. RNNInst
  1288. .def(py::init<uint32_t, bool, bool, uint32_t, float, ::megdnn::param::RNN::NonlineMode, ::megdnn::param::RNN::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("dropout") = 0.f, py::arg("nonlineMode") = ::megdnn::param::RNN::NonlineMode::IDENTITY, py::arg("fwd_mode") = ::megdnn::param::RNN::FwdMode::TRAINING, py::arg("scope") = {})
  1289. .def_readwrite("num_layers", &RNN::num_layers)
  1290. .def_readwrite("bidirectional", &RNN::bidirectional)
  1291. .def_readwrite("bias", &RNN::bias)
  1292. .def_readwrite("hidden_size", &RNN::hidden_size)
  1293. .def_readwrite("dropout", &RNN::dropout)
  1294. .def_readwrite("nonlineMode", &RNN::nonlineMode)
  1295. .def_readwrite("fwd_mode", &RNN::fwd_mode);
  1296. py::class_<RNNCell, std::shared_ptr<RNNCell>, OpDef> RNNCellInst(m, "RNNCell");
  1297. RNNCellInst.attr("NonlineMode") = RNNInst.attr("NonlineMode");
  1298. RNNCellInst
  1299. .def(py::init<::megdnn::param::RNNCell::NonlineMode, std::string>(), py::arg("nonlineMode") = ::megdnn::param::RNNCell::NonlineMode::IDENTITY, py::arg("scope") = {})
  1300. .def_readwrite("nonlineMode", &RNNCell::nonlineMode);
  1301. py::class_<ROIAlign, std::shared_ptr<ROIAlign>, OpDef> ROIAlignInst(m, "ROIAlign");
  1302. py::enum_<ROIAlign::Mode>(ROIAlignInst, "Mode")
  1303. .value("MAX", ROIAlign::Mode::MAX)
  1304. .value("AVERAGE", ROIAlign::Mode::AVERAGE)
  1305. .def(py::init([](const std::string& in) {
  1306. auto&& str = normalize_enum(in);
  1307. if (str == "MAX") return ROIAlign::Mode::MAX;
  1308. if (str == "AVERAGE") return ROIAlign::Mode::AVERAGE;
  1309. throw py::cast_error("invalid enum value " + in);
  1310. }));
  1311. py::implicitly_convertible<std::string, ROIAlign::Mode>();
  1312. ROIAlignInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1313. ROIAlignInst
  1314. .def(py::init<::megdnn::param::ROIAlign::Mode, ::megdnn::param::ROIAlign::Format, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("mode") = ::megdnn::param::ROIAlign::Mode::MAX, py::arg("format") = ::megdnn::param::ROIAlign::Format::NCHW, py::arg("spatial_scale") = 1.0, py::arg("offset") = 0.0, py::arg("pooled_height") = 1, py::arg("pooled_width") = 1, py::arg("sample_height") = 2, py::arg("sample_width") = 2, py::arg("scope") = {})
  1315. .def_readwrite("mode", &ROIAlign::mode)
  1316. .def_readwrite("format", &ROIAlign::format)
  1317. .def_readwrite("spatial_scale", &ROIAlign::spatial_scale)
  1318. .def_readwrite("offset", &ROIAlign::offset)
  1319. .def_readwrite("pooled_height", &ROIAlign::pooled_height)
  1320. .def_readwrite("pooled_width", &ROIAlign::pooled_width)
  1321. .def_readwrite("sample_height", &ROIAlign::sample_height)
  1322. .def_readwrite("sample_width", &ROIAlign::sample_width);
  1323. py::class_<ROIPooling, std::shared_ptr<ROIPooling>, OpDef> ROIPoolingInst(m, "ROIPooling");
  1324. py::enum_<ROIPooling::Mode>(ROIPoolingInst, "Mode")
  1325. .value("MAX", ROIPooling::Mode::MAX)
  1326. .value("AVERAGE", ROIPooling::Mode::AVERAGE)
  1327. .def(py::init([](const std::string& in) {
  1328. auto&& str = normalize_enum(in);
  1329. if (str == "MAX") return ROIPooling::Mode::MAX;
  1330. if (str == "AVERAGE") return ROIPooling::Mode::AVERAGE;
  1331. throw py::cast_error("invalid enum value " + in);
  1332. }));
  1333. py::implicitly_convertible<std::string, ROIPooling::Mode>();
  1334. ROIPoolingInst
  1335. .def(py::init<::megdnn::param::ROIPooling::Mode, float, std::string>(), py::arg("mode") = ::megdnn::param::ROIPooling::Mode::MAX, py::arg("scale") = 1.f, py::arg("scope") = {})
  1336. .def_readwrite("mode", &ROIPooling::mode)
  1337. .def_readwrite("scale", &ROIPooling::scale);
  1338. py::class_<Reduce, std::shared_ptr<Reduce>, OpDef> ReduceInst(m, "Reduce");
  1339. py::enum_<Reduce::Mode>(ReduceInst, "Mode")
  1340. .value("SUM", Reduce::Mode::SUM)
  1341. .value("SUM_SQR", Reduce::Mode::SUM_SQR)
  1342. .value("PRODUCT", Reduce::Mode::PRODUCT)
  1343. .value("MIN", Reduce::Mode::MIN)
  1344. .value("MAX", Reduce::Mode::MAX)
  1345. .value("MEAN", Reduce::Mode::MEAN)
  1346. .def(py::init([](const std::string& in) {
  1347. auto&& str = normalize_enum(in);
  1348. if (str == "SUM") return Reduce::Mode::SUM;
  1349. if (str == "SUM_SQR") return Reduce::Mode::SUM_SQR;
  1350. if (str == "PRODUCT") return Reduce::Mode::PRODUCT;
  1351. if (str == "MIN") return Reduce::Mode::MIN;
  1352. if (str == "MAX") return Reduce::Mode::MAX;
  1353. if (str == "MEAN") return Reduce::Mode::MEAN;
  1354. throw py::cast_error("invalid enum value " + in);
  1355. }));
  1356. py::implicitly_convertible<std::string, Reduce::Mode>();
  1357. py::enum_<Reduce::DataType>(ReduceInst, "DataType")
  1358. .value("DEFAULT", Reduce::DataType::DEFAULT)
  1359. .value("FLOAT_IO16xC32", Reduce::DataType::FLOAT_IO16xC32)
  1360. .value("FLOAT_O32xC32", Reduce::DataType::FLOAT_O32xC32)
  1361. .value("FLOAT_O16xC32", Reduce::DataType::FLOAT_O16xC32)
  1362. .value("QUINT_I8xO32", Reduce::DataType::QUINT_I8xO32)
  1363. .value("QINT_I8xO32", Reduce::DataType::QINT_I8xO32)
  1364. .def(py::init([](const std::string& in) {
  1365. auto&& str = normalize_enum(in);
  1366. if (str == "DEFAULT") return Reduce::DataType::DEFAULT;
  1367. if (str == "FLOAT_IO16xC32") return Reduce::DataType::FLOAT_IO16xC32;
  1368. if (str == "FLOAT_O32xC32") return Reduce::DataType::FLOAT_O32xC32;
  1369. if (str == "FLOAT_O16xC32") return Reduce::DataType::FLOAT_O16xC32;
  1370. if (str == "QUINT_I8xO32") return Reduce::DataType::QUINT_I8xO32;
  1371. if (str == "QINT_I8xO32") return Reduce::DataType::QINT_I8xO32;
  1372. throw py::cast_error("invalid enum value " + in);
  1373. }));
  1374. py::implicitly_convertible<std::string, Reduce::DataType>();
  1375. ReduceInst
  1376. .def(py::init<::megdnn::param::Reduce::Mode, int32_t, ::megdnn::param::Reduce::DataType, bool, std::string>(), py::arg("mode") = ::megdnn::param::Reduce::Mode::SUM, py::arg("axis") = 2147483647, py::arg("data_type") = ::megdnn::param::Reduce::DataType::DEFAULT, py::arg("keepdim") = true, py::arg("scope") = {})
  1377. .def_readwrite("mode", &Reduce::mode)
  1378. .def_readwrite("axis", &Reduce::axis)
  1379. .def_readwrite("data_type", &Reduce::data_type)
  1380. .def_readwrite("keepdim", &Reduce::keepdim);
  1381. py::class_<RegionRestrictedConvolution, std::shared_ptr<RegionRestrictedConvolution>, OpDef> RegionRestrictedConvolutionInst(m, "RegionRestrictedConvolution");
  1382. RegionRestrictedConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  1383. RegionRestrictedConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  1384. RegionRestrictedConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1385. RegionRestrictedConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  1386. RegionRestrictedConvolutionInst
  1387. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
  1388. .def_readwrite("mode", &RegionRestrictedConvolution::mode)
  1389. .def_readwrite("pad_h", &RegionRestrictedConvolution::pad_h)
  1390. .def_readwrite("pad_w", &RegionRestrictedConvolution::pad_w)
  1391. .def_readwrite("stride_h", &RegionRestrictedConvolution::stride_h)
  1392. .def_readwrite("stride_w", &RegionRestrictedConvolution::stride_w)
  1393. .def_readwrite("dilate_h", &RegionRestrictedConvolution::dilate_h)
  1394. .def_readwrite("dilate_w", &RegionRestrictedConvolution::dilate_w)
  1395. .def_readwrite("sparse", &RegionRestrictedConvolution::sparse)
  1396. .def_readwrite("format", &RegionRestrictedConvolution::format)
  1397. .def_readwrite("compute_mode", &RegionRestrictedConvolution::compute_mode);
  1398. py::class_<RegionRestrictedConvolutionBackwardData, std::shared_ptr<RegionRestrictedConvolutionBackwardData>, OpDef> RegionRestrictedConvolutionBackwardDataInst(m, "RegionRestrictedConvolutionBackwardData");
  1399. RegionRestrictedConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  1400. RegionRestrictedConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  1401. RegionRestrictedConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1402. RegionRestrictedConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  1403. RegionRestrictedConvolutionBackwardDataInst
  1404. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
  1405. .def_readwrite("mode", &RegionRestrictedConvolutionBackwardData::mode)
  1406. .def_readwrite("pad_h", &RegionRestrictedConvolutionBackwardData::pad_h)
  1407. .def_readwrite("pad_w", &RegionRestrictedConvolutionBackwardData::pad_w)
  1408. .def_readwrite("stride_h", &RegionRestrictedConvolutionBackwardData::stride_h)
  1409. .def_readwrite("stride_w", &RegionRestrictedConvolutionBackwardData::stride_w)
  1410. .def_readwrite("dilate_h", &RegionRestrictedConvolutionBackwardData::dilate_h)
  1411. .def_readwrite("dilate_w", &RegionRestrictedConvolutionBackwardData::dilate_w)
  1412. .def_readwrite("sparse", &RegionRestrictedConvolutionBackwardData::sparse)
  1413. .def_readwrite("format", &RegionRestrictedConvolutionBackwardData::format)
  1414. .def_readwrite("compute_mode", &RegionRestrictedConvolutionBackwardData::compute_mode);
  1415. py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap");
  1416. py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode")
  1417. .value("NEAREST", Remap::InterpolationMode::NEAREST)
  1418. .value("LINEAR", Remap::InterpolationMode::LINEAR)
  1419. .value("AREA", Remap::InterpolationMode::AREA)
  1420. .value("CUBIC", Remap::InterpolationMode::CUBIC)
  1421. .value("LANCZOS4", Remap::InterpolationMode::LANCZOS4)
  1422. .def(py::init([](const std::string& in) {
  1423. auto&& str = normalize_enum(in);
  1424. if (str == "NEAREST") return Remap::InterpolationMode::NEAREST;
  1425. if (str == "LINEAR") return Remap::InterpolationMode::LINEAR;
  1426. if (str == "AREA") return Remap::InterpolationMode::AREA;
  1427. if (str == "CUBIC") return Remap::InterpolationMode::CUBIC;
  1428. if (str == "LANCZOS4") return Remap::InterpolationMode::LANCZOS4;
  1429. throw py::cast_error("invalid enum value " + in);
  1430. }));
  1431. py::implicitly_convertible<std::string, Remap::InterpolationMode>();
  1432. py::enum_<Remap::BorderMode>(RemapInst, "BorderMode")
  1433. .value("REPLICATE", Remap::BorderMode::REPLICATE)
  1434. .value("REFLECT", Remap::BorderMode::REFLECT)
  1435. .value("REFLECT_101", Remap::BorderMode::REFLECT_101)
  1436. .value("WRAP", Remap::BorderMode::WRAP)
  1437. .value("CONSTANT", Remap::BorderMode::CONSTANT)
  1438. .value("TRANSPARENT", Remap::BorderMode::TRANSPARENT)
  1439. .value("ISOLATED", Remap::BorderMode::ISOLATED)
  1440. .def(py::init([](const std::string& in) {
  1441. auto&& str = normalize_enum(in);
  1442. if (str == "REPLICATE") return Remap::BorderMode::REPLICATE;
  1443. if (str == "REFLECT") return Remap::BorderMode::REFLECT;
  1444. if (str == "REFLECT_101") return Remap::BorderMode::REFLECT_101;
  1445. if (str == "WRAP") return Remap::BorderMode::WRAP;
  1446. if (str == "CONSTANT") return Remap::BorderMode::CONSTANT;
  1447. if (str == "TRANSPARENT") return Remap::BorderMode::TRANSPARENT;
  1448. if (str == "ISOLATED") return Remap::BorderMode::ISOLATED;
  1449. throw py::cast_error("invalid enum value " + in);
  1450. }));
  1451. py::implicitly_convertible<std::string, Remap::BorderMode>();
  1452. RemapInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1453. RemapInst
  1454. .def(py::init<::megdnn::param::Remap::InterpolationMode, ::megdnn::param::Remap::BorderMode, ::megdnn::param::Remap::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::Remap::InterpolationMode::LINEAR, py::arg("border_type") = ::megdnn::param::Remap::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::Remap::Format::NHWC, py::arg("scalar") = 0.f, py::arg("scope") = {})
  1455. .def_readwrite("imode", &Remap::imode)
  1456. .def_readwrite("border_type", &Remap::border_type)
  1457. .def_readwrite("format", &Remap::format)
  1458. .def_readwrite("scalar", &Remap::scalar);
  1459. py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef> RemoteRecvInst(m, "RemoteRecv");
  1460. RemoteRecvInst
  1461. .def(py::init<std::string, std::string, uint32_t, uint32_t, ::mgb::CompNode, std::vector<int32_t>, ::megdnn::DType, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_from"), py::arg("cn"), py::arg("shape"), py::arg("dtype"), py::arg("backend"), py::arg("scope") = {})
  1462. .def(py::init<>())
  1463. .def_readwrite("key", &RemoteRecv::key)
  1464. .def_readwrite("addr", &RemoteRecv::addr)
  1465. .def_readwrite("port", &RemoteRecv::port)
  1466. .def_readwrite("rank_from", &RemoteRecv::rank_from)
  1467. .def_readwrite("cn", &RemoteRecv::cn)
  1468. .def_readwrite("shape", &RemoteRecv::shape)
  1469. .def_readwrite("dtype", &RemoteRecv::dtype)
  1470. .def_readwrite("backend", &RemoteRecv::backend);
  1471. py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef> RemoteSendInst(m, "RemoteSend");
  1472. RemoteSendInst
  1473. .def(py::init<std::string, std::string, uint32_t, uint32_t, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_to"), py::arg("backend"), py::arg("scope") = {})
  1474. .def(py::init<>())
  1475. .def_readwrite("key", &RemoteSend::key)
  1476. .def_readwrite("addr", &RemoteSend::addr)
  1477. .def_readwrite("port", &RemoteSend::port)
  1478. .def_readwrite("rank_to", &RemoteSend::rank_to)
  1479. .def_readwrite("backend", &RemoteSend::backend);
  1480. py::class_<RemoveAxis, std::shared_ptr<RemoveAxis>, OpDef> RemoveAxisInst(m, "RemoveAxis");
  1481. RemoveAxisInst
  1482. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
  1483. .def(py::init<>())
  1484. .def_readwrite("axis", &RemoveAxis::axis);
  1485. py::class_<Reshape, std::shared_ptr<Reshape>, OpDef> ReshapeInst(m, "Reshape");
  1486. ReshapeInst
  1487. .def(py::init<int32_t, std::vector<int32_t>, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("shape"), py::arg("scope") = {})
  1488. .def(py::init<>())
  1489. .def_readwrite("axis", &Reshape::axis)
  1490. .def_readwrite("shape", &Reshape::shape);
  1491. py::class_<Resize, std::shared_ptr<Resize>, OpDef> ResizeInst(m, "Resize");
  1492. ResizeInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1493. ResizeInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1494. ResizeInst
  1495. .def(py::init<::megdnn::param::Resize::InterpolationMode, ::megdnn::param::Resize::Format, std::string>(), py::arg("imode") = ::megdnn::param::Resize::InterpolationMode::LINEAR, py::arg("format") = ::megdnn::param::Resize::Format::NHWC, py::arg("scope") = {})
  1496. .def_readwrite("imode", &Resize::imode)
  1497. .def_readwrite("format", &Resize::format);
  1498. py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD");
  1499. SVDInst
  1500. .def(py::init<bool, bool, std::string>(), py::arg("full_matrices") = false, py::arg("compute_uv") = true, py::arg("scope") = {})
  1501. .def_readwrite("full_matrices", &SVD::full_matrices)
  1502. .def_readwrite("compute_uv", &SVD::compute_uv);
  1503. py::class_<SetMeshIndexing, std::shared_ptr<SetMeshIndexing>, OpDef> SetMeshIndexingInst(m, "SetMeshIndexing");
  1504. SetMeshIndexingInst
  1505. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1506. .def(py::init<>())
  1507. .def_readwrite("items", &SetMeshIndexing::items);
  1508. py::class_<SetSubtensor, std::shared_ptr<SetSubtensor>, OpDef> SetSubtensorInst(m, "SetSubtensor");
  1509. SetSubtensorInst
  1510. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1511. .def(py::init<>())
  1512. .def_readwrite("items", &SetSubtensor::items);
  1513. py::class_<ShuffleRNG, std::shared_ptr<ShuffleRNG>, OpDef> ShuffleRNGInst(m, "ShuffleRNG");
  1514. ShuffleRNGInst
  1515. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1516. .def(py::init<>())
  1517. .def_readwrite("seed", &ShuffleRNG::seed)
  1518. .def_readwrite("handle", &ShuffleRNG::handle);
  1519. py::class_<SlidingWindowTranspose, std::shared_ptr<SlidingWindowTranspose>, OpDef> SlidingWindowTransposeInst(m, "SlidingWindowTranspose");
  1520. SlidingWindowTransposeInst
  1521. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("out_h") = 0, py::arg("out_w") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
  1522. .def_readwrite("out_h", &SlidingWindowTranspose::out_h)
  1523. .def_readwrite("out_w", &SlidingWindowTranspose::out_w)
  1524. .def_readwrite("pad_h", &SlidingWindowTranspose::pad_h)
  1525. .def_readwrite("pad_w", &SlidingWindowTranspose::pad_w)
  1526. .def_readwrite("stride_h", &SlidingWindowTranspose::stride_h)
  1527. .def_readwrite("stride_w", &SlidingWindowTranspose::stride_w)
  1528. .def_readwrite("dilate_h", &SlidingWindowTranspose::dilate_h)
  1529. .def_readwrite("dilate_w", &SlidingWindowTranspose::dilate_w)
  1530. .def_readwrite("window_h", &SlidingWindowTranspose::window_h)
  1531. .def_readwrite("window_w", &SlidingWindowTranspose::window_w);
  1532. py::class_<Softmax, std::shared_ptr<Softmax>, OpDef> SoftmaxInst(m, "Softmax");
  1533. SoftmaxInst
  1534. .def(py::init<int32_t, std::string>(), py::arg("axis") = -1, py::arg("scope") = {})
  1535. .def_readwrite("axis", &Softmax::axis);
  1536. py::class_<Split, std::shared_ptr<Split>, OpDef> SplitInst(m, "Split");
  1537. SplitInst
  1538. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis"), py::arg("nsections"), py::arg("scope") = {})
  1539. .def(py::init<>())
  1540. .def_readwrite("axis", &Split::axis)
  1541. .def_readwrite("nsections", &Split::nsections);
  1542. py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
  1543. SubtensorInst
  1544. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1545. .def(py::init<>())
  1546. .def_readwrite("items", &Subtensor::items);
  1547. py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
  1548. TQTInst
  1549. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1550. .def_readwrite("qmin", &TQT::qmin)
  1551. .def_readwrite("qmax", &TQT::qmax);
  1552. py::class_<TensorRTRuntime, std::shared_ptr<TensorRTRuntime>, OpDef> TensorRTRuntimeInst(m, "TensorRTRuntime");
  1553. TensorRTRuntimeInst
  1554. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  1555. .def(py::init<>())
  1556. .def_readwrite("buf", &TensorRTRuntime::buf)
  1557. .def_readwrite("buf_size", &TensorRTRuntime::buf_size);
  1558. py::class_<TopK, std::shared_ptr<TopK>, OpDef> TopKInst(m, "TopK");
  1559. py::enum_<TopK::Mode>(TopKInst, "Mode")
  1560. .value("KTH_ONLY", TopK::Mode::KTH_ONLY)
  1561. .value("VALUE_IDX_NOSORT", TopK::Mode::VALUE_IDX_NOSORT)
  1562. .value("VALUE_IDX_SORTED", TopK::Mode::VALUE_IDX_SORTED)
  1563. .def(py::init([](const std::string& in) {
  1564. auto&& str = normalize_enum(in);
  1565. if (str == "KTH_ONLY") return TopK::Mode::KTH_ONLY;
  1566. if (str == "VALUE_IDX_NOSORT") return TopK::Mode::VALUE_IDX_NOSORT;
  1567. if (str == "VALUE_IDX_SORTED") return TopK::Mode::VALUE_IDX_SORTED;
  1568. throw py::cast_error("invalid enum value " + in);
  1569. }));
  1570. py::implicitly_convertible<std::string, TopK::Mode>();
  1571. TopKInst
  1572. .def(py::init<::megdnn::param::TopK::Mode, std::string>(), py::arg("mode") = ::megdnn::param::TopK::Mode::KTH_ONLY, py::arg("scope") = {})
  1573. .def_readwrite("mode", &TopK::mode);
  1574. py::class_<TypeCvt, std::shared_ptr<TypeCvt>, OpDef> TypeCvtInst(m, "TypeCvt");
  1575. TypeCvtInst
  1576. .def(py::init<::megdnn::DType, std::string>(), py::arg("dtype"), py::arg("scope") = {})
  1577. .def(py::init<>())
  1578. .def_readwrite("dtype", &TypeCvt::dtype);
  1579. py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef> UniformRNGInst(m, "UniformRNG");
  1580. UniformRNGInst
  1581. .def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
  1582. .def(py::init<>())
  1583. .def_readwrite("seed", &UniformRNG::seed)
  1584. .def_readwrite("dtype", &UniformRNG::dtype)
  1585. .def_readwrite("handle", &UniformRNG::handle);
  1586. py::class_<WarpAffine, std::shared_ptr<WarpAffine>, OpDef> WarpAffineInst(m, "WarpAffine");
  1587. WarpAffineInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1588. WarpAffineInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1589. WarpAffineInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1590. WarpAffineInst
  1591. .def(py::init<::megdnn::param::WarpAffine::InterpolationMode, ::megdnn::param::WarpAffine::BorderMode, float, ::megdnn::param::WarpAffine::Format, std::string>(), py::arg("imode") = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR, py::arg("border_mode") = ::megdnn::param::WarpAffine::BorderMode::REPLICATE, py::arg("border_val") = .0f, py::arg("format") = ::megdnn::param::WarpAffine::Format::NHWC, py::arg("scope") = {})
  1592. .def_readwrite("imode", &WarpAffine::imode)
  1593. .def_readwrite("border_mode", &WarpAffine::border_mode)
  1594. .def_readwrite("border_val", &WarpAffine::border_val)
  1595. .def_readwrite("format", &WarpAffine::format);
  1596. py::class_<WarpPerspective, std::shared_ptr<WarpPerspective>, OpDef> WarpPerspectiveInst(m, "WarpPerspective");
  1597. WarpPerspectiveInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1598. WarpPerspectiveInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1599. WarpPerspectiveInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1600. WarpPerspectiveInst
  1601. .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
  1602. .def_readwrite("imode", &WarpPerspective::imode)
  1603. .def_readwrite("bmode", &WarpPerspective::bmode)
  1604. .def_readwrite("format", &WarpPerspective::format)
  1605. .def_readwrite("border_val", &WarpPerspective::border_val);
  1606. py::class_<WarpPerspectiveBackwardData, std::shared_ptr<WarpPerspectiveBackwardData>, OpDef> WarpPerspectiveBackwardDataInst(m, "WarpPerspectiveBackwardData");
  1607. WarpPerspectiveBackwardDataInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1608. WarpPerspectiveBackwardDataInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1609. WarpPerspectiveBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1610. WarpPerspectiveBackwardDataInst
  1611. .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
  1612. .def_readwrite("imode", &WarpPerspectiveBackwardData::imode)
  1613. .def_readwrite("bmode", &WarpPerspectiveBackwardData::bmode)
  1614. .def_readwrite("format", &WarpPerspectiveBackwardData::format)
  1615. .def_readwrite("border_val", &WarpPerspectiveBackwardData::border_val);
  1616. py::class_<WarpPerspectiveBackwardMat, std::shared_ptr<WarpPerspectiveBackwardMat>, OpDef> WarpPerspectiveBackwardMatInst(m, "WarpPerspectiveBackwardMat");
  1617. WarpPerspectiveBackwardMatInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1618. WarpPerspectiveBackwardMatInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1619. WarpPerspectiveBackwardMatInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1620. WarpPerspectiveBackwardMatInst
  1621. .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
  1622. .def_readwrite("imode", &WarpPerspectiveBackwardMat::imode)
  1623. .def_readwrite("bmode", &WarpPerspectiveBackwardMat::bmode)
  1624. .def_readwrite("format", &WarpPerspectiveBackwardMat::format)
  1625. .def_readwrite("border_val", &WarpPerspectiveBackwardMat::border_val);
  1626. // clang-format on