You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opdef.py.inl 112 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915
  1. // clang-format off
  // Python binding for the AdaptivePooling op: registers class "AdaptivePooling" in scope `m`,
  // held by std::shared_ptr and declared with OpDef as its base type.
  2. py::class_<AdaptivePooling, std::shared_ptr<AdaptivePooling>, OpDef> AdaptivePoolingInst(m, "AdaptivePooling");
  // Nested enum AdaptivePooling.Mode; every member is exposed under its C++ name.
  3. py::enum_<AdaptivePooling::Mode>(AdaptivePoolingInst, "Mode")
  4. .value("MAX", AdaptivePooling::Mode::MAX)
  5. .value("AVERAGE", AdaptivePooling::Mode::AVERAGE)
  6. .value("AVERAGE_COUNT_EXCLUDE_PADDING", AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING)
  // Extra constructor from a string: the input is passed through normalize_enum (defined
  // elsewhere) and matched against the member names; any other value throws py::cast_error.
  7. .def(py::init([](const std::string& in) {
  8. auto&& str = normalize_enum(in);
  9. if (str == "MAX") return AdaptivePooling::Mode::MAX;
  10. if (str == "AVERAGE") return AdaptivePooling::Mode::AVERAGE;
  11. if (str == "AVERAGE_COUNT_EXCLUDE_PADDING") return AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
  12. throw py::cast_error("invalid enum value " + in);
  13. }));
  // Allow a std::string argument to be implicitly converted to Mode.
  14. py::implicitly_convertible<std::string, AdaptivePooling::Mode>();
  // Nested enum AdaptivePooling.Format. Other op classes in this file alias this same Python
  // object via AdaptivePoolingInst.attr("Format") instead of registering their own.
  15. py::enum_<AdaptivePooling::Format>(AdaptivePoolingInst, "Format")
  16. .value("NCHW", AdaptivePooling::Format::NCHW)
  17. .value("NHWC", AdaptivePooling::Format::NHWC)
  18. .value("NHWCD4", AdaptivePooling::Format::NHWCD4)
  19. .value("NCHW4", AdaptivePooling::Format::NCHW4)
  20. .value("NCHW8", AdaptivePooling::Format::NCHW8)
  21. .value("NCHW32", AdaptivePooling::Format::NCHW32)
  22. .value("NCHW88", AdaptivePooling::Format::NCHW88)
  23. .value("NCHW44", AdaptivePooling::Format::NCHW44)
  24. .value("NCHW44_DOT", AdaptivePooling::Format::NCHW44_DOT)
  25. .value("NCHW4_NCHW32", AdaptivePooling::Format::NCHW4_NCHW32)
  26. .value("NCHW32_NCHW4", AdaptivePooling::Format::NCHW32_NCHW4)
  27. .value("NCHW4_NCHW", AdaptivePooling::Format::NCHW4_NCHW)
  28. .value("NHWC_NCHW", AdaptivePooling::Format::NHWC_NCHW)
  29. .value("NHWC_NCHW4_IC_SMALL", AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL)
  30. .value("NCHW_NCHW4_IC_SMALL", AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL)
  31. .value("CHWN4", AdaptivePooling::Format::CHWN4)
  32. .value("NCHW64", AdaptivePooling::Format::NCHW64)
  33. .value("NCHW4_NHWC", AdaptivePooling::Format::NCHW4_NHWC)
  // String constructor for Format, same scheme as for Mode above: normalize, match by name,
  // throw py::cast_error when nothing matches.
  34. .def(py::init([](const std::string& in) {
  35. auto&& str = normalize_enum(in);
  36. if (str == "NCHW") return AdaptivePooling::Format::NCHW;
  37. if (str == "NHWC") return AdaptivePooling::Format::NHWC;
  38. if (str == "NHWCD4") return AdaptivePooling::Format::NHWCD4;
  39. if (str == "NCHW4") return AdaptivePooling::Format::NCHW4;
  40. if (str == "NCHW8") return AdaptivePooling::Format::NCHW8;
  41. if (str == "NCHW32") return AdaptivePooling::Format::NCHW32;
  42. if (str == "NCHW88") return AdaptivePooling::Format::NCHW88;
  43. if (str == "NCHW44") return AdaptivePooling::Format::NCHW44;
  44. if (str == "NCHW44_DOT") return AdaptivePooling::Format::NCHW44_DOT;
  45. if (str == "NCHW4_NCHW32") return AdaptivePooling::Format::NCHW4_NCHW32;
  46. if (str == "NCHW32_NCHW4") return AdaptivePooling::Format::NCHW32_NCHW4;
  47. if (str == "NCHW4_NCHW") return AdaptivePooling::Format::NCHW4_NCHW;
  48. if (str == "NHWC_NCHW") return AdaptivePooling::Format::NHWC_NCHW;
  49. if (str == "NHWC_NCHW4_IC_SMALL") return AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL;
  50. if (str == "NCHW_NCHW4_IC_SMALL") return AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL;
  51. if (str == "CHWN4") return AdaptivePooling::Format::CHWN4;
  52. if (str == "NCHW64") return AdaptivePooling::Format::NCHW64;
  53. if (str == "NCHW4_NHWC") return AdaptivePooling::Format::NCHW4_NHWC;
  54. throw py::cast_error("invalid enum value " + in);
  55. }));
  // Allow a std::string argument to be implicitly converted to Format.
  56. py::implicitly_convertible<std::string, AdaptivePooling::Format>();
  57. AdaptivePoolingInst
  // Keyword constructor: mode defaults to MAX, format to NCHW, scope to {}; shape has no default.
  58. .def(py::init<::megdnn::param::AdaptivePooling::Mode, ::megdnn::param::AdaptivePooling::Format, std::vector<int32_t>, std::string>(), py::arg("mode") = ::megdnn::param::AdaptivePooling::Mode::MAX, py::arg("format") = ::megdnn::param::AdaptivePooling::Format::NCHW, py::arg("shape"), py::arg("scope") = {})
  // Zero-argument constructor.
  59. .def(py::init<>())
  // Expose the op's fields as read/write Python attributes.
  60. .def_readwrite("mode", &AdaptivePooling::mode)
  61. .def_readwrite("format", &AdaptivePooling::format)
  62. .def_readwrite("shape", &AdaptivePooling::shape);
  // Python binding for the AddAxis op (shared_ptr holder, OpDef declared as base).
  63. py::class_<AddAxis, std::shared_ptr<AddAxis>, OpDef> AddAxisInst(m, "AddAxis");
  64. AddAxisInst
  // Keyword constructor: axis (vector of int32_t) has no default; scope defaults to {}.
  65. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
  // Zero-argument constructor.
  66. .def(py::init<>())
  // Read/write attribute mapped onto AddAxis::axis.
  67. .def_readwrite("axis", &AddAxis::axis);
  // Python binding for the Argmax op (shared_ptr holder, OpDef declared as base).
  68. py::class_<Argmax, std::shared_ptr<Argmax>, OpDef> ArgmaxInst(m, "Argmax");
  69. ArgmaxInst
  // Only constructor: both arguments have defaults (axis = 0, scope = {}), so no separate
  // zero-argument overload is registered.
  70. .def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
  // Read/write attribute mapped onto Argmax::axis.
  71. .def_readwrite("axis", &Argmax::axis);
  // Python binding for the Argmin op (shared_ptr holder, OpDef declared as base).
  72. py::class_<Argmin, std::shared_ptr<Argmin>, OpDef> ArgminInst(m, "Argmin");
  73. ArgminInst
  // Only constructor: both arguments have defaults (axis = 0, scope = {}), so no separate
  // zero-argument overload is registered.
  74. .def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
  // Read/write attribute mapped onto Argmin::axis.
  75. .def_readwrite("axis", &Argmin::axis);
  // Python binding for the Argsort op (shared_ptr holder, OpDef declared as base).
  76. py::class_<Argsort, std::shared_ptr<Argsort>, OpDef> ArgsortInst(m, "Argsort");
  // Nested enum Argsort.Order with members ASCENDING and DESCENDING.
  77. py::enum_<Argsort::Order>(ArgsortInst, "Order")
  78. .value("ASCENDING", Argsort::Order::ASCENDING)
  79. .value("DESCENDING", Argsort::Order::DESCENDING)
  // String constructor: input goes through normalize_enum (defined elsewhere) and is matched
  // against the member names; any other value throws py::cast_error.
  80. .def(py::init([](const std::string& in) {
  81. auto&& str = normalize_enum(in);
  82. if (str == "ASCENDING") return Argsort::Order::ASCENDING;
  83. if (str == "DESCENDING") return Argsort::Order::DESCENDING;
  84. throw py::cast_error("invalid enum value " + in);
  85. }));
  // Allow a std::string argument to be implicitly converted to Order.
  86. py::implicitly_convertible<std::string, Argsort::Order>();
  87. ArgsortInst
  // Only constructor: order defaults to ASCENDING and scope to {}.
  88. .def(py::init<::megdnn::param::Argsort::Order, std::string>(), py::arg("order") = ::megdnn::param::Argsort::Order::ASCENDING, py::arg("scope") = {})
  // Read/write attribute mapped onto Argsort::order.
  89. .def_readwrite("order", &Argsort::order);
  // Python binding for the AssertEqual op (shared_ptr holder, OpDef declared as base).
  90. py::class_<AssertEqual, std::shared_ptr<AssertEqual>, OpDef> AssertEqualInst(m, "AssertEqual");
  91. AssertEqualInst
  // Only constructor: maxerr defaults to 0.0001, verbose to false, scope to {}.
  92. .def(py::init<float, bool, std::string>(), py::arg("maxerr") = 0.0001, py::arg("verbose") = false, py::arg("scope") = {})
  // Read/write attributes mapped onto the op's fields.
  93. .def_readwrite("maxerr", &AssertEqual::maxerr)
  94. .def_readwrite("verbose", &AssertEqual::verbose);
  // Python binding for the AtlasRuntime op (shared_ptr holder, OpDef declared as base).
  95. py::class_<AtlasRuntime, std::shared_ptr<AtlasRuntime>, OpDef> AtlasRuntimeInst(m, "AtlasRuntime");
  96. AtlasRuntimeInst
  // Keyword constructor: buf (std::string) and buf_size (size_t) have no defaults; scope defaults to {}.
  97. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  // Zero-argument constructor.
  98. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  99. .def_readwrite("buf", &AtlasRuntime::buf)
  100. .def_readwrite("buf_size", &AtlasRuntime::buf_size);
  // Python binding for the Barrier op (shared_ptr holder, OpDef declared as base).
  101. py::class_<Barrier, std::shared_ptr<Barrier>, OpDef> BarrierInst(m, "Barrier");
  102. BarrierInst
  // Keyword constructor: comp_node (::mgb::CompNode) and nr_outputs (uint32_t) have no
  // defaults; scope defaults to {}.
  103. .def(py::init<::mgb::CompNode, uint32_t, std::string>(), py::arg("comp_node"), py::arg("nr_outputs"), py::arg("scope") = {})
  // Zero-argument constructor.
  104. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  105. .def_readwrite("comp_node", &Barrier::comp_node)
  106. .def_readwrite("nr_outputs", &Barrier::nr_outputs);
  // Python binding for the BatchConvBias op (shared_ptr holder, OpDef declared as base).
  // The NonlineMode, Mode, Sparse, ComputeMode and Strategy enums registered here are later
  // aliased by other op classes via BatchConvBiasInst.attr(...).
  107. py::class_<BatchConvBias, std::shared_ptr<BatchConvBias>, OpDef> BatchConvBiasInst(m, "BatchConvBias");
  // Nested enum BatchConvBias.NonlineMode.
  108. py::enum_<BatchConvBias::NonlineMode>(BatchConvBiasInst, "NonlineMode")
  109. .value("IDENTITY", BatchConvBias::NonlineMode::IDENTITY)
  110. .value("RELU", BatchConvBias::NonlineMode::RELU)
  111. .value("SIGMOID", BatchConvBias::NonlineMode::SIGMOID)
  112. .value("H_SWISH", BatchConvBias::NonlineMode::H_SWISH)
  // String constructor: input goes through normalize_enum (defined elsewhere) and is matched
  // against the member names; any other value throws py::cast_error.
  113. .def(py::init([](const std::string& in) {
  114. auto&& str = normalize_enum(in);
  115. if (str == "IDENTITY") return BatchConvBias::NonlineMode::IDENTITY;
  116. if (str == "RELU") return BatchConvBias::NonlineMode::RELU;
  117. if (str == "SIGMOID") return BatchConvBias::NonlineMode::SIGMOID;
  118. if (str == "H_SWISH") return BatchConvBias::NonlineMode::H_SWISH;
  119. throw py::cast_error("invalid enum value " + in);
  120. }));
  // Allow a std::string argument to be implicitly converted to NonlineMode.
  121. py::implicitly_convertible<std::string, BatchConvBias::NonlineMode>();
  // Nested enum BatchConvBias.Mode, with the same string-constructor scheme.
  122. py::enum_<BatchConvBias::Mode>(BatchConvBiasInst, "Mode")
  123. .value("CROSS_CORRELATION", BatchConvBias::Mode::CROSS_CORRELATION)
  124. .value("CONVOLUTION", BatchConvBias::Mode::CONVOLUTION)
  125. .def(py::init([](const std::string& in) {
  126. auto&& str = normalize_enum(in);
  127. if (str == "CROSS_CORRELATION") return BatchConvBias::Mode::CROSS_CORRELATION;
  128. if (str == "CONVOLUTION") return BatchConvBias::Mode::CONVOLUTION;
  129. throw py::cast_error("invalid enum value " + in);
  130. }));
  131. py::implicitly_convertible<std::string, BatchConvBias::Mode>();
  // Nested enum BatchConvBias.Sparse, with the same string-constructor scheme.
  132. py::enum_<BatchConvBias::Sparse>(BatchConvBiasInst, "Sparse")
  133. .value("DENSE", BatchConvBias::Sparse::DENSE)
  134. .value("GROUP", BatchConvBias::Sparse::GROUP)
  135. .def(py::init([](const std::string& in) {
  136. auto&& str = normalize_enum(in);
  137. if (str == "DENSE") return BatchConvBias::Sparse::DENSE;
  138. if (str == "GROUP") return BatchConvBias::Sparse::GROUP;
  139. throw py::cast_error("invalid enum value " + in);
  140. }));
  141. py::implicitly_convertible<std::string, BatchConvBias::Sparse>();
  // Format is not registered again: the attribute points at the enum object already
  // registered on AdaptivePooling.
  142. BatchConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  // Nested enum BatchConvBias.ComputeMode, with the same string-constructor scheme.
  143. py::enum_<BatchConvBias::ComputeMode>(BatchConvBiasInst, "ComputeMode")
  144. .value("DEFAULT", BatchConvBias::ComputeMode::DEFAULT)
  145. .value("FLOAT32", BatchConvBias::ComputeMode::FLOAT32)
  146. .def(py::init([](const std::string& in) {
  147. auto&& str = normalize_enum(in);
  148. if (str == "DEFAULT") return BatchConvBias::ComputeMode::DEFAULT;
  149. if (str == "FLOAT32") return BatchConvBias::ComputeMode::FLOAT32;
  150. throw py::cast_error("invalid enum value " + in);
  151. }));
  152. py::implicitly_convertible<std::string, BatchConvBias::ComputeMode>();
  // Nested enum BatchConvBias.Strategy. Unlike the other enums it also defines __or__ and
  // __and__, so two Strategy values can be combined with | and & from Python.
  153. py::enum_<BatchConvBias::Strategy>(BatchConvBiasInst, "Strategy")
  154. .value("HEURISTIC", BatchConvBias::Strategy::HEURISTIC)
  155. .value("PROFILE", BatchConvBias::Strategy::PROFILE)
  156. .value("REPRODUCIBLE", BatchConvBias::Strategy::REPRODUCIBLE)
  157. .value("OPTIMIZED", BatchConvBias::Strategy::OPTIMIZED)
  // Bitwise OR of the two values, computed on uint32_t and cast back to Strategy.
  158. .def("__or__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
  159. return static_cast<BatchConvBias::Strategy>(uint32_t(s0) | uint32_t(s1));
  160. })
  // Bitwise AND of the two values, computed on uint32_t and cast back to Strategy.
  161. .def("__and__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
  162. return static_cast<BatchConvBias::Strategy>(uint32_t(s0) & uint32_t(s1));
  163. })
  // String constructor: accepts exactly one member name per call (after normalize_enum).
  164. .def(py::init([](const std::string& in) {
  165. auto&& str = normalize_enum(in);
  166. if (str == "HEURISTIC") return BatchConvBias::Strategy::HEURISTIC;
  167. if (str == "PROFILE") return BatchConvBias::Strategy::PROFILE;
  168. if (str == "REPRODUCIBLE") return BatchConvBias::Strategy::REPRODUCIBLE;
  169. if (str == "OPTIMIZED") return BatchConvBias::Strategy::OPTIMIZED;
  170. throw py::cast_error("invalid enum value " + in);
  171. }));
  172. py::implicitly_convertible<std::string, BatchConvBias::Strategy>();
  173. BatchConvBiasInst
  // Keyword constructor. Every argument except dtype has a default: pads 0, strides and
  // dilations 1, workspace_limit 18446744073709551615 (the largest uint64_t value), and
  // strategy the enum value with underlying integer 1.
  // NOTE(review): which named Strategy member equals 1 is not visible here - confirm.
  174. .def(py::init<::megdnn::param::BatchConvBias::NonlineMode, ::megdnn::param::BatchConvBias::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::BatchConvBias::Sparse, ::megdnn::param::BatchConvBias::Format, ::megdnn::param::BatchConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::BatchConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::BatchConvBias::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  // Zero-argument constructor.
  175. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  176. .def_readwrite("nonlineMode", &BatchConvBias::nonlineMode)
  177. .def_readwrite("mode", &BatchConvBias::mode)
  178. .def_readwrite("pad_h", &BatchConvBias::pad_h)
  179. .def_readwrite("pad_w", &BatchConvBias::pad_w)
  180. .def_readwrite("stride_h", &BatchConvBias::stride_h)
  181. .def_readwrite("stride_w", &BatchConvBias::stride_w)
  182. .def_readwrite("dilate_h", &BatchConvBias::dilate_h)
  183. .def_readwrite("dilate_w", &BatchConvBias::dilate_w)
  184. .def_readwrite("sparse", &BatchConvBias::sparse)
  185. .def_readwrite("format", &BatchConvBias::format)
  186. .def_readwrite("compute_mode", &BatchConvBias::compute_mode)
  187. .def_readwrite("strategy", &BatchConvBias::strategy)
  188. .def_readwrite("workspace_limit", &BatchConvBias::workspace_limit)
  189. .def_readwrite("dtype", &BatchConvBias::dtype);
  // Python binding for the BatchNorm op (shared_ptr holder, OpDef declared as base).
  190. py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> BatchNormInst(m, "BatchNorm");
  // Nested enum BatchNorm.ParamDim.
  191. py::enum_<BatchNorm::ParamDim>(BatchNormInst, "ParamDim")
  192. .value("DIM_11HW", BatchNorm::ParamDim::DIM_11HW)
  193. .value("DIM_1CHW", BatchNorm::ParamDim::DIM_1CHW)
  194. .value("DIM_1C11", BatchNorm::ParamDim::DIM_1C11)
  195. .value("DIM_111C", BatchNorm::ParamDim::DIM_111C)
  // String constructor: input goes through normalize_enum (defined elsewhere) and is matched
  // against the member names; any other value throws py::cast_error.
  196. .def(py::init([](const std::string& in) {
  197. auto&& str = normalize_enum(in);
  198. if (str == "DIM_11HW") return BatchNorm::ParamDim::DIM_11HW;
  199. if (str == "DIM_1CHW") return BatchNorm::ParamDim::DIM_1CHW;
  200. if (str == "DIM_1C11") return BatchNorm::ParamDim::DIM_1C11;
  201. if (str == "DIM_111C") return BatchNorm::ParamDim::DIM_111C;
  202. throw py::cast_error("invalid enum value " + in);
  203. }));
  // Allow a std::string argument to be implicitly converted to ParamDim.
  204. py::implicitly_convertible<std::string, BatchNorm::ParamDim>();
  // Nested enum BatchNorm.FwdMode, with the same string-constructor scheme.
  205. py::enum_<BatchNorm::FwdMode>(BatchNormInst, "FwdMode")
  206. .value("TRAINING", BatchNorm::FwdMode::TRAINING)
  207. .value("INFERENCE", BatchNorm::FwdMode::INFERENCE)
  208. .def(py::init([](const std::string& in) {
  209. auto&& str = normalize_enum(in);
  210. if (str == "TRAINING") return BatchNorm::FwdMode::TRAINING;
  211. if (str == "INFERENCE") return BatchNorm::FwdMode::INFERENCE;
  212. throw py::cast_error("invalid enum value " + in);
  213. }));
  214. py::implicitly_convertible<std::string, BatchNorm::FwdMode>();
  215. BatchNormInst
  // Only constructor; every argument has a default (param_dim DIM_11HW, fwd_mode TRAINING,
  // epsilon 1e-4f, avg_factor 1.f, scale 1.f, bias 0.f, scope {}). epsilon and avg_factor are
  // declared double in the signature although their defaults are written as float literals.
  // NOTE(review): the signature names ::megdnn::param::BN enums while the enums above are
  // registered as BatchNorm::ParamDim / FwdMode - presumably aliases of the same types; confirm.
  216. .def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
  // Read/write attributes mapped onto the op's fields.
  217. .def_readwrite("param_dim", &BatchNorm::param_dim)
  218. .def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
  219. .def_readwrite("epsilon", &BatchNorm::epsilon)
  220. .def_readwrite("avg_factor", &BatchNorm::avg_factor)
  221. .def_readwrite("scale", &BatchNorm::scale)
  222. .def_readwrite("bias", &BatchNorm::bias);
  // Python binding for the BatchNormBackward op (shared_ptr holder, OpDef declared as base).
  223. py::class_<BatchNormBackward, std::shared_ptr<BatchNormBackward>, OpDef> BatchNormBackwardInst(m, "BatchNormBackward");
  // ParamDim and FwdMode are not registered again: both attributes point at the enum
  // objects already registered on BatchNorm.
  224. BatchNormBackwardInst.attr("ParamDim") = BatchNormInst.attr("ParamDim");
  225. BatchNormBackwardInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  226. BatchNormBackwardInst
  // Only constructor; same argument list and defaults as the BatchNorm constructor
  // (param_dim DIM_11HW, fwd_mode TRAINING, epsilon 1e-4f, avg_factor 1.f, scale 1.f,
  // bias 0.f, scope {}).
  227. .def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
  // Read/write attributes mapped onto the op's fields.
  228. .def_readwrite("param_dim", &BatchNormBackward::param_dim)
  229. .def_readwrite("fwd_mode", &BatchNormBackward::fwd_mode)
  230. .def_readwrite("epsilon", &BatchNormBackward::epsilon)
  231. .def_readwrite("avg_factor", &BatchNormBackward::avg_factor)
  232. .def_readwrite("scale", &BatchNormBackward::scale)
  233. .def_readwrite("bias", &BatchNormBackward::bias);
  // Python binding for the BatchedIncrMeshIndexing op (shared_ptr holder, OpDef declared as base).
  234. py::class_<BatchedIncrMeshIndexing, std::shared_ptr<BatchedIncrMeshIndexing>, OpDef> BatchedIncrMeshIndexingInst(m, "BatchedIncrMeshIndexing");
  235. BatchedIncrMeshIndexingInst
  // Keyword constructor: items is a vector of (int8_t, bool, bool, bool, bool) tuples and has
  // no default; scope defaults to {}. The meaning of the tuple fields is not visible here.
  236. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  // Zero-argument constructor.
  237. .def(py::init<>())
  // Read/write attribute mapped onto BatchedIncrMeshIndexing::items.
  238. .def_readwrite("items", &BatchedIncrMeshIndexing::items);
  // Python binding for the BatchedMatrixMul op (shared_ptr holder, OpDef declared as base).
  239. py::class_<BatchedMatrixMul, std::shared_ptr<BatchedMatrixMul>, OpDef> BatchedMatrixMulInst(m, "BatchedMatrixMul");
  // Nested enum BatchedMatrixMul.ComputeMode. It is registered separately here rather than
  // aliased from BatchConvBias.ComputeMode.
  240. py::enum_<BatchedMatrixMul::ComputeMode>(BatchedMatrixMulInst, "ComputeMode")
  241. .value("DEFAULT", BatchedMatrixMul::ComputeMode::DEFAULT)
  242. .value("FLOAT32", BatchedMatrixMul::ComputeMode::FLOAT32)
  // String constructor: input goes through normalize_enum (defined elsewhere) and is matched
  // against the member names; any other value throws py::cast_error.
  243. .def(py::init([](const std::string& in) {
  244. auto&& str = normalize_enum(in);
  245. if (str == "DEFAULT") return BatchedMatrixMul::ComputeMode::DEFAULT;
  246. if (str == "FLOAT32") return BatchedMatrixMul::ComputeMode::FLOAT32;
  247. throw py::cast_error("invalid enum value " + in);
  248. }));
  // Allow a std::string argument to be implicitly converted to ComputeMode.
  249. py::implicitly_convertible<std::string, BatchedMatrixMul::ComputeMode>();
  // Nested enum BatchedMatrixMul.Format (matrix layouts; distinct from AdaptivePooling.Format),
  // with the same string-constructor scheme.
  250. py::enum_<BatchedMatrixMul::Format>(BatchedMatrixMulInst, "Format")
  251. .value("DEFAULT", BatchedMatrixMul::Format::DEFAULT)
  252. .value("MK4", BatchedMatrixMul::Format::MK4)
  253. .value("MK8", BatchedMatrixMul::Format::MK8)
  254. .value("MK4_DOT", BatchedMatrixMul::Format::MK4_DOT)
  255. .value("N32K4_DOT", BatchedMatrixMul::Format::N32K4_DOT)
  256. .def(py::init([](const std::string& in) {
  257. auto&& str = normalize_enum(in);
  258. if (str == "DEFAULT") return BatchedMatrixMul::Format::DEFAULT;
  259. if (str == "MK4") return BatchedMatrixMul::Format::MK4;
  260. if (str == "MK8") return BatchedMatrixMul::Format::MK8;
  261. if (str == "MK4_DOT") return BatchedMatrixMul::Format::MK4_DOT;
  262. if (str == "N32K4_DOT") return BatchedMatrixMul::Format::N32K4_DOT;
  263. throw py::cast_error("invalid enum value " + in);
  264. }));
  265. py::implicitly_convertible<std::string, BatchedMatrixMul::Format>();
  // Strategy is not registered again: the attribute points at the enum object already
  // registered on BatchConvBias.
  266. BatchedMatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  267. BatchedMatrixMulInst
  // Keyword constructor. dimA and dimB have no defaults; transposeA/transposeB default to
  // false, compute_mode and format to DEFAULT, workspace_limit to 18446744073709551615 (the
  // largest uint64_t value), strategy to the enum value with underlying integer 1, scope to {}.
  268. .def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
  // Zero-argument constructor.
  269. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  270. .def_readwrite("transposeA", &BatchedMatrixMul::transposeA)
  271. .def_readwrite("transposeB", &BatchedMatrixMul::transposeB)
  272. .def_readwrite("compute_mode", &BatchedMatrixMul::compute_mode)
  273. .def_readwrite("format", &BatchedMatrixMul::format)
  274. .def_readwrite("strategy", &BatchedMatrixMul::strategy)
  275. .def_readwrite("workspace_limit", &BatchedMatrixMul::workspace_limit)
  276. .def_readwrite("dimA", &BatchedMatrixMul::dimA)
  277. .def_readwrite("dimB", &BatchedMatrixMul::dimB);
  // Python binding for the BatchedMeshIndexing op (shared_ptr holder, OpDef declared as base).
  278. py::class_<BatchedMeshIndexing, std::shared_ptr<BatchedMeshIndexing>, OpDef> BatchedMeshIndexingInst(m, "BatchedMeshIndexing");
  279. BatchedMeshIndexingInst
  // Keyword constructor: items is a vector of (int8_t, bool, bool, bool, bool) tuples and has
  // no default; scope defaults to {}. The meaning of the tuple fields is not visible here.
  280. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  // Zero-argument constructor.
  281. .def(py::init<>())
  // Read/write attribute mapped onto BatchedMeshIndexing::items.
  282. .def_readwrite("items", &BatchedMeshIndexing::items);
  // Python binding for the BatchedSetMeshIndexing op (shared_ptr holder, OpDef declared as base).
  283. py::class_<BatchedSetMeshIndexing, std::shared_ptr<BatchedSetMeshIndexing>, OpDef> BatchedSetMeshIndexingInst(m, "BatchedSetMeshIndexing");
  284. BatchedSetMeshIndexingInst
  // Keyword constructor: items is a vector of (int8_t, bool, bool, bool, bool) tuples and has
  // no default; scope defaults to {}. The meaning of the tuple fields is not visible here.
  285. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  // Zero-argument constructor.
  286. .def(py::init<>())
  // Read/write attribute mapped onto BatchedSetMeshIndexing::items.
  287. .def_readwrite("items", &BatchedSetMeshIndexing::items);
  // Python binding for the BetaRNG op (shared_ptr holder, OpDef declared as base).
  288. py::class_<BetaRNG, std::shared_ptr<BetaRNG>, OpDef> BetaRNGInst(m, "BetaRNG");
  289. BetaRNGInst
  // Keyword constructor: seed (uint64_t) defaults to 0, handle (size_t) has no default,
  // scope defaults to {}.
  290. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  // Zero-argument constructor.
  291. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  292. .def_readwrite("seed", &BetaRNG::seed)
  293. .def_readwrite("handle", &BetaRNG::handle);
  // Python binding for the Borrow op (shared_ptr holder, OpDef declared as base).
  294. py::class_<Borrow, std::shared_ptr<Borrow>, OpDef> BorrowInst(m, "Borrow");
  295. BorrowInst
  // Keyword constructor: comp_node (::mgb::CompNode) has no default; scope defaults to {}.
  296. .def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
  // Zero-argument constructor.
  297. .def(py::init<>())
  // Read/write attribute mapped onto Borrow::comp_node.
  298. .def_readwrite("comp_node", &Borrow::comp_node);
  // Python binding for the Broadcast op (shared_ptr holder, OpDef declared as base).
  299. py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef> BroadcastInst(m, "Broadcast");
  300. BroadcastInst
  // Keyword constructor: shape (vector of int32_t) has no default; scope defaults to {}.
  301. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("shape"), py::arg("scope") = {})
  // Zero-argument constructor.
  302. .def(py::init<>())
  // Read/write attribute mapped onto Broadcast::shape.
  303. .def_readwrite("shape", &Broadcast::shape);
  // Python binding for the CambriconRuntime op (shared_ptr holder, OpDef declared as base).
  304. py::class_<CambriconRuntime, std::shared_ptr<CambriconRuntime>, OpDef> CambriconRuntimeInst(m, "CambriconRuntime");
  305. CambriconRuntimeInst
  // Keyword constructor: buf, buf_size, symbol and tensor_dim_mutable have no defaults;
  // only scope defaults to {}.
  306. .def(py::init<std::string, size_t, std::string, bool, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("symbol"), py::arg("tensor_dim_mutable"), py::arg("scope") = {})
  // Zero-argument constructor.
  307. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  308. .def_readwrite("buf", &CambriconRuntime::buf)
  309. .def_readwrite("buf_size", &CambriconRuntime::buf_size)
  310. .def_readwrite("symbol", &CambriconRuntime::symbol)
  311. .def_readwrite("tensor_dim_mutable", &CambriconRuntime::tensor_dim_mutable);
  // Python binding for the CheckNonFinite op (shared_ptr holder, OpDef declared as base).
  312. py::class_<CheckNonFinite, std::shared_ptr<CheckNonFinite>, OpDef> CheckNonFiniteInst(m, "CheckNonFinite");
  313. CheckNonFiniteInst
  // Only constructor: both arguments have defaults (scale = 1.0, scope = {}), so no separate
  // zero-argument overload is registered.
  314. .def(py::init<float, std::string>(), py::arg("scale") = 1.0, py::arg("scope") = {})
  // Read/write attribute mapped onto CheckNonFinite::scale.
  315. .def_readwrite("scale", &CheckNonFinite::scale);
  // Python binding for the CollectiveComm op (shared_ptr holder, OpDef declared as base).
  316. py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef> CollectiveCommInst(m, "CollectiveComm");
  // Nested enum CollectiveComm.Mode; every member is exposed under its C++ name.
  317. py::enum_<CollectiveComm::Mode>(CollectiveCommInst, "Mode")
  318. .value("REDUCE_SUM", CollectiveComm::Mode::REDUCE_SUM)
  319. .value("BROADCAST", CollectiveComm::Mode::BROADCAST)
  320. .value("ALL_GATHER", CollectiveComm::Mode::ALL_GATHER)
  321. .value("REDUCE_SCATTER_SUM", CollectiveComm::Mode::REDUCE_SCATTER_SUM)
  322. .value("ALL_REDUCE_SUM", CollectiveComm::Mode::ALL_REDUCE_SUM)
  323. .value("ALL_REDUCE_MAX", CollectiveComm::Mode::ALL_REDUCE_MAX)
  324. .value("ALL_REDUCE_MIN", CollectiveComm::Mode::ALL_REDUCE_MIN)
  325. .value("ALL_REDUCE_PROD", CollectiveComm::Mode::ALL_REDUCE_PROD)
  326. .value("GATHER", CollectiveComm::Mode::GATHER)
  327. .value("SCATTER", CollectiveComm::Mode::SCATTER)
  328. .value("ALL_TO_ALL", CollectiveComm::Mode::ALL_TO_ALL)
  // String constructor: input goes through normalize_enum (defined elsewhere) and is matched
  // against the member names; any other value throws py::cast_error.
  329. .def(py::init([](const std::string& in) {
  330. auto&& str = normalize_enum(in);
  331. if (str == "REDUCE_SUM") return CollectiveComm::Mode::REDUCE_SUM;
  332. if (str == "BROADCAST") return CollectiveComm::Mode::BROADCAST;
  333. if (str == "ALL_GATHER") return CollectiveComm::Mode::ALL_GATHER;
  334. if (str == "REDUCE_SCATTER_SUM") return CollectiveComm::Mode::REDUCE_SCATTER_SUM;
  335. if (str == "ALL_REDUCE_SUM") return CollectiveComm::Mode::ALL_REDUCE_SUM;
  336. if (str == "ALL_REDUCE_MAX") return CollectiveComm::Mode::ALL_REDUCE_MAX;
  337. if (str == "ALL_REDUCE_MIN") return CollectiveComm::Mode::ALL_REDUCE_MIN;
  338. if (str == "ALL_REDUCE_PROD") return CollectiveComm::Mode::ALL_REDUCE_PROD;
  339. if (str == "GATHER") return CollectiveComm::Mode::GATHER;
  340. if (str == "SCATTER") return CollectiveComm::Mode::SCATTER;
  341. if (str == "ALL_TO_ALL") return CollectiveComm::Mode::ALL_TO_ALL;
  342. throw py::cast_error("invalid enum value " + in);
  343. }));
  // Allow a std::string argument to be implicitly converted to Mode.
  344. py::implicitly_convertible<std::string, CollectiveComm::Mode>();
  345. CollectiveCommInst
  // Keyword constructor. Only mode (REDUCE_SUM) and scope ({}) have defaults; key through
  // comp_node must all be supplied. Note that comp_node is taken as a std::string in this
  // signature, not as ::mgb::CompNode.
  346. .def(py::init<::megdnn::param::CollectiveComm::Mode, std::string, uint32_t, uint32_t, bool, bool, std::string, uint32_t, ::megdnn::DType, std::string, std::string, std::string>(), py::arg("mode") = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM, py::arg("key"), py::arg("nr_devices"), py::arg("rank"), py::arg("is_root"), py::arg("local_grad"), py::arg("addr"), py::arg("port"), py::arg("dtype"), py::arg("backend"), py::arg("comp_node"), py::arg("scope") = {})
  // Zero-argument constructor.
  347. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  348. .def_readwrite("mode", &CollectiveComm::mode)
  349. .def_readwrite("key", &CollectiveComm::key)
  350. .def_readwrite("nr_devices", &CollectiveComm::nr_devices)
  351. .def_readwrite("rank", &CollectiveComm::rank)
  352. .def_readwrite("is_root", &CollectiveComm::is_root)
  353. .def_readwrite("local_grad", &CollectiveComm::local_grad)
  354. .def_readwrite("addr", &CollectiveComm::addr)
  355. .def_readwrite("port", &CollectiveComm::port)
  356. .def_readwrite("dtype", &CollectiveComm::dtype)
  357. .def_readwrite("backend", &CollectiveComm::backend)
  358. .def_readwrite("comp_node", &CollectiveComm::comp_node);
  // Python binding for the Concat op (shared_ptr holder, OpDef declared as base).
  359. py::class_<Concat, std::shared_ptr<Concat>, OpDef> ConcatInst(m, "Concat");
  360. ConcatInst
  // Keyword constructor: axis (int32_t) defaults to 0, comp_node (::mgb::CompNode) has no
  // default, scope defaults to {}.
  361. .def(py::init<int32_t, ::mgb::CompNode, std::string>(), py::arg("axis") = 0, py::arg("comp_node"), py::arg("scope") = {})
  // Zero-argument constructor.
  362. .def(py::init<>())
  // Read/write attributes mapped onto the op's fields.
  363. .def_readwrite("axis", &Concat::axis)
  364. .def_readwrite("comp_node", &Concat::comp_node);
  // Python binding for the CondTake op: registers class "CondTake" on module m.
  // Only a zero-argument constructor is exposed; no attributes are bound.
  365. py::class_<CondTake, std::shared_ptr<CondTake>, OpDef> CondTakeInst(m, "CondTake");
  366. CondTakeInst
  367. .def(py::init<>());
  // Python binding for the ConvBias op: registers class "ConvBias" on module m.
  368. py::class_<ConvBias, std::shared_ptr<ConvBias>, OpDef> ConvBiasInst(m, "ConvBias");
  // The nested enum attributes are not registered again: they are aliased to
  // the objects already attached to BatchConvBiasInst (and, for "Format", to
  // AdaptivePoolingInst), so those handles must be populated before this point.
  369. ConvBiasInst.attr("NonlineMode") = BatchConvBiasInst.attr("NonlineMode");
  370. ConvBiasInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  371. ConvBiasInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  372. ConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  373. ConvBiasInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  374. ConvBiasInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  375. ConvBiasInst
  // Keyword constructor. Every argument has a default except "dtype".
  // "strategy" defaults to the Strategy enumerator whose underlying value is 1
  // (its name is not visible here); "workspace_limit" defaults to
  // 18446744073709551615, i.e. 2^64 - 1, the largest uint64_t value.
  376. .def(py::init<::megdnn::param::ConvBias::NonlineMode, ::megdnn::param::ConvBias::Mode, ::megdnn::param::ConvBias::Sparse, ::megdnn::param::ConvBias::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::ConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::ConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION, py::arg("sparse") = ::megdnn::param::ConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::ConvBias::Format::NCHW, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("compute_mode") = ::megdnn::param::ConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  // Zero-argument constructor.
  377. .def(py::init<>())
  // Read/write attributes bound to the C++ members ("scope" is not among them).
  378. .def_readwrite("nonlineMode", &ConvBias::nonlineMode)
  379. .def_readwrite("mode", &ConvBias::mode)
  380. .def_readwrite("sparse", &ConvBias::sparse)
  381. .def_readwrite("format", &ConvBias::format)
  382. .def_readwrite("pad_h", &ConvBias::pad_h)
  383. .def_readwrite("pad_w", &ConvBias::pad_w)
  384. .def_readwrite("stride_h", &ConvBias::stride_h)
  385. .def_readwrite("stride_w", &ConvBias::stride_w)
  386. .def_readwrite("dilate_h", &ConvBias::dilate_h)
  387. .def_readwrite("dilate_w", &ConvBias::dilate_w)
  388. .def_readwrite("compute_mode", &ConvBias::compute_mode)
  389. .def_readwrite("strategy", &ConvBias::strategy)
  390. .def_readwrite("workspace_limit", &ConvBias::workspace_limit)
  391. .def_readwrite("dtype", &ConvBias::dtype);
  // Python binding for the Convolution op: registers class "Convolution" on module m.
  392. py::class_<Convolution, std::shared_ptr<Convolution>, OpDef> ConvolutionInst(m, "Convolution");
  // Enum attributes are aliases of the objects already attached to
  // BatchConvBiasInst / AdaptivePoolingInst rather than fresh registrations.
  393. ConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  394. ConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  395. ConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  396. ConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  397. ConvolutionInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  398. ConvolutionInst
  // Keyword constructor; every argument has a default, and no separate
  // py::init<>() overload is registered for this class.
  // "strategy" defaults to the Strategy enumerator with underlying value 1;
  // "workspace_limit" defaults to 2^64 - 1 (the largest uint64_t value).
  399. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  400. .def_readwrite("mode", &Convolution::mode)
  401. .def_readwrite("pad_h", &Convolution::pad_h)
  402. .def_readwrite("pad_w", &Convolution::pad_w)
  403. .def_readwrite("stride_h", &Convolution::stride_h)
  404. .def_readwrite("stride_w", &Convolution::stride_w)
  405. .def_readwrite("dilate_h", &Convolution::dilate_h)
  406. .def_readwrite("dilate_w", &Convolution::dilate_w)
  407. .def_readwrite("sparse", &Convolution::sparse)
  408. .def_readwrite("format", &Convolution::format)
  409. .def_readwrite("compute_mode", &Convolution::compute_mode)
  410. .def_readwrite("strategy", &Convolution::strategy)
  411. .def_readwrite("workspace_limit", &Convolution::workspace_limit);
  // Python binding for the Convolution3D op: registers class "Convolution3D"
  // on module m together with its nested Mode, Sparse, DataType and Format enums.
  412. py::class_<Convolution3D, std::shared_ptr<Convolution3D>, OpDef> Convolution3DInst(m, "Convolution3D");
  // Nested enum "Mode". Besides the named values it gets a constructor from
  // std::string: the input is passed through normalize_enum (defined elsewhere;
  // NOTE(review): presumably canonicalizes the spelling, e.g. case — confirm)
  // and compared against the enumerator names. A non-matching string raises
  // py::cast_error carrying the original, un-normalized input.
  413. py::enum_<Convolution3D::Mode>(Convolution3DInst, "Mode")
  414. .value("CROSS_CORRELATION", Convolution3D::Mode::CROSS_CORRELATION)
  415. .value("CONVOLUTION", Convolution3D::Mode::CONVOLUTION)
  416. .def(py::init([](const std::string& in) {
  417. auto&& str = normalize_enum(in);
  418. if (str == "CROSS_CORRELATION") return Convolution3D::Mode::CROSS_CORRELATION;
  419. if (str == "CONVOLUTION") return Convolution3D::Mode::CONVOLUTION;
  420. throw py::cast_error("invalid enum value " + in);
  421. }));
  // Lets a Python str be converted implicitly wherever a Mode is expected.
  422. py::implicitly_convertible<std::string, Convolution3D::Mode>();
  // Nested enum "Sparse", with the same string-constructor pattern.
  423. py::enum_<Convolution3D::Sparse>(Convolution3DInst, "Sparse")
  424. .value("DENSE", Convolution3D::Sparse::DENSE)
  425. .value("GROUP", Convolution3D::Sparse::GROUP)
  426. .def(py::init([](const std::string& in) {
  427. auto&& str = normalize_enum(in);
  428. if (str == "DENSE") return Convolution3D::Sparse::DENSE;
  429. if (str == "GROUP") return Convolution3D::Sparse::GROUP;
  430. throw py::cast_error("invalid enum value " + in);
  431. }));
  432. py::implicitly_convertible<std::string, Convolution3D::Sparse>();
  // Nested enum "DataType", with the same string-constructor pattern.
  // NOTE(review): "FLOAT_IO16xC32" contains a lowercase 'x'; whether a
  // normalized string can still equal it depends on what normalize_enum does
  // with case — verify against its definition.
  433. py::enum_<Convolution3D::DataType>(Convolution3DInst, "DataType")
  434. .value("FLOAT", Convolution3D::DataType::FLOAT)
  435. .value("FLOAT_IO16xC32", Convolution3D::DataType::FLOAT_IO16xC32)
  436. .def(py::init([](const std::string& in) {
  437. auto&& str = normalize_enum(in);
  438. if (str == "FLOAT") return Convolution3D::DataType::FLOAT;
  439. if (str == "FLOAT_IO16xC32") return Convolution3D::DataType::FLOAT_IO16xC32;
  440. throw py::cast_error("invalid enum value " + in);
  441. }));
  442. py::implicitly_convertible<std::string, Convolution3D::DataType>();
  // Nested enum "Format", with the same string-constructor pattern.
  443. py::enum_<Convolution3D::Format>(Convolution3DInst, "Format")
  444. .value("NCDHW", Convolution3D::Format::NCDHW)
  445. .value("NDHWC", Convolution3D::Format::NDHWC)
  446. .def(py::init([](const std::string& in) {
  447. auto&& str = normalize_enum(in);
  448. if (str == "NCDHW") return Convolution3D::Format::NCDHW;
  449. if (str == "NDHWC") return Convolution3D::Format::NDHWC;
  450. throw py::cast_error("invalid enum value " + in);
  451. }));
  452. py::implicitly_convertible<std::string, Convolution3D::Format>();
  // "Strategy" is not registered here; it aliases the object on BatchConvBiasInst.
  453. Convolution3DInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  454. Convolution3DInst
  // Keyword constructor; every argument has a default, and no separate
  // py::init<>() overload is registered for this class.
  // "strategy" defaults to the Strategy enumerator with underlying value 1;
  // "workspace_limit" defaults to 2^64 - 1 (the largest uint64_t value).
  455. .def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  456. .def_readwrite("mode", &Convolution3D::mode)
  457. .def_readwrite("pad_d", &Convolution3D::pad_d)
  458. .def_readwrite("pad_h", &Convolution3D::pad_h)
  459. .def_readwrite("pad_w", &Convolution3D::pad_w)
  460. .def_readwrite("stride_d", &Convolution3D::stride_d)
  461. .def_readwrite("stride_h", &Convolution3D::stride_h)
  462. .def_readwrite("stride_w", &Convolution3D::stride_w)
  463. .def_readwrite("dilate_d", &Convolution3D::dilate_d)
  464. .def_readwrite("dilate_h", &Convolution3D::dilate_h)
  465. .def_readwrite("dilate_w", &Convolution3D::dilate_w)
  466. .def_readwrite("sparse", &Convolution3D::sparse)
  467. .def_readwrite("data_type", &Convolution3D::data_type)
  468. .def_readwrite("format", &Convolution3D::format)
  469. .def_readwrite("strategy", &Convolution3D::strategy)
  470. .def_readwrite("workspace_limit", &Convolution3D::workspace_limit);
  // Python binding for the Convolution3DBackwardData op: registers class
  // "Convolution3DBackwardData" on module m.
  471. py::class_<Convolution3DBackwardData, std::shared_ptr<Convolution3DBackwardData>, OpDef> Convolution3DBackwardDataInst(m, "Convolution3DBackwardData");
  // Mode/Sparse/DataType/Format alias the enum objects registered on
  // Convolution3DInst; Strategy aliases the one on BatchConvBiasInst.
  472. Convolution3DBackwardDataInst.attr("Mode") = Convolution3DInst.attr("Mode");
  473. Convolution3DBackwardDataInst.attr("Sparse") = Convolution3DInst.attr("Sparse");
  474. Convolution3DBackwardDataInst.attr("DataType") = Convolution3DInst.attr("DataType");
  475. Convolution3DBackwardDataInst.attr("Format") = Convolution3DInst.attr("Format");
  476. Convolution3DBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  477. Convolution3DBackwardDataInst
  // Keyword constructor with the same argument list and defaults as the
  // Convolution3D binding; every argument has a default and no separate
  // py::init<>() overload is registered.
  478. .def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  479. .def_readwrite("mode", &Convolution3DBackwardData::mode)
  480. .def_readwrite("pad_d", &Convolution3DBackwardData::pad_d)
  481. .def_readwrite("pad_h", &Convolution3DBackwardData::pad_h)
  482. .def_readwrite("pad_w", &Convolution3DBackwardData::pad_w)
  483. .def_readwrite("stride_d", &Convolution3DBackwardData::stride_d)
  484. .def_readwrite("stride_h", &Convolution3DBackwardData::stride_h)
  485. .def_readwrite("stride_w", &Convolution3DBackwardData::stride_w)
  486. .def_readwrite("dilate_d", &Convolution3DBackwardData::dilate_d)
  487. .def_readwrite("dilate_h", &Convolution3DBackwardData::dilate_h)
  488. .def_readwrite("dilate_w", &Convolution3DBackwardData::dilate_w)
  489. .def_readwrite("sparse", &Convolution3DBackwardData::sparse)
  490. .def_readwrite("data_type", &Convolution3DBackwardData::data_type)
  491. .def_readwrite("format", &Convolution3DBackwardData::format)
  492. .def_readwrite("strategy", &Convolution3DBackwardData::strategy)
  493. .def_readwrite("workspace_limit", &Convolution3DBackwardData::workspace_limit);
  // Python binding for the ConvolutionBackwardData op: registers class
  // "ConvolutionBackwardData" on module m.
  494. py::class_<ConvolutionBackwardData, std::shared_ptr<ConvolutionBackwardData>, OpDef> ConvolutionBackwardDataInst(m, "ConvolutionBackwardData");
  // Enum attributes are aliases of the objects already attached to
  // BatchConvBiasInst / AdaptivePoolingInst rather than fresh registrations.
  495. ConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  496. ConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  497. ConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  498. ConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  499. ConvolutionBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  500. ConvolutionBackwardDataInst
  // Keyword constructor. Every argument has a default except "dtype".
  // "strategy" defaults to the Strategy enumerator with underlying value 1;
  // "workspace_limit" defaults to 2^64 - 1 (the largest uint64_t value).
  501. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
  // Zero-argument constructor.
  502. .def(py::init<>())
  // Read/write attributes bound to the C++ members.
  503. .def_readwrite("mode", &ConvolutionBackwardData::mode)
  504. .def_readwrite("pad_h", &ConvolutionBackwardData::pad_h)
  505. .def_readwrite("pad_w", &ConvolutionBackwardData::pad_w)
  506. .def_readwrite("stride_h", &ConvolutionBackwardData::stride_h)
  507. .def_readwrite("stride_w", &ConvolutionBackwardData::stride_w)
  508. .def_readwrite("dilate_h", &ConvolutionBackwardData::dilate_h)
  509. .def_readwrite("dilate_w", &ConvolutionBackwardData::dilate_w)
  510. .def_readwrite("sparse", &ConvolutionBackwardData::sparse)
  511. .def_readwrite("format", &ConvolutionBackwardData::format)
  512. .def_readwrite("compute_mode", &ConvolutionBackwardData::compute_mode)
  513. .def_readwrite("strategy", &ConvolutionBackwardData::strategy)
  514. .def_readwrite("workspace_limit", &ConvolutionBackwardData::workspace_limit)
  515. .def_readwrite("dtype", &ConvolutionBackwardData::dtype);
  // Python binding for the Copy op: registers class "Copy" on module m.
  516. py::class_<Copy, std::shared_ptr<Copy>, OpDef> CopyInst(m, "Copy");
  517. CopyInst
  // Keyword constructor: "comp_node" has no default, "scope" defaults to {}.
  518. .def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
  // Zero-argument constructor.
  519. .def(py::init<>())
  // Read/write attribute bound to the C++ member.
  520. .def_readwrite("comp_node", &Copy::comp_node);
  // Python binding for the Correlation op: registers class "Correlation" on
  // module m together with its own nested "Format" enum.
  521. py::class_<Correlation, std::shared_ptr<Correlation>, OpDef> CorrelationInst(m, "Correlation");
  // Nested enum "Format": one .value() entry per enumerator, exposed under the
  // same name as the C++ enumerator.
  522. py::enum_<Correlation::Format>(CorrelationInst, "Format")
  523. .value("NCHW", Correlation::Format::NCHW)
  524. .value("NHWC", Correlation::Format::NHWC)
  525. .value("NHWCD4", Correlation::Format::NHWCD4)
  526. .value("NCHW4", Correlation::Format::NCHW4)
  527. .value("NCHW8", Correlation::Format::NCHW8)
  528. .value("NCHW32", Correlation::Format::NCHW32)
  529. .value("NCHW88", Correlation::Format::NCHW88)
  530. .value("NCHW44", Correlation::Format::NCHW44)
  531. .value("NCHW44_DOT", Correlation::Format::NCHW44_DOT)
  532. .value("NCHW_WINOGRAD", Correlation::Format::NCHW_WINOGRAD)
  533. .value("NCHW88_WINOGRAD", Correlation::Format::NCHW88_WINOGRAD)
  534. .value("NCHW44_WINOGRAD", Correlation::Format::NCHW44_WINOGRAD)
  535. .value("NCHW4_NCHW32", Correlation::Format::NCHW4_NCHW32)
  536. .value("NCHW32_NCHW4", Correlation::Format::NCHW32_NCHW4)
  537. .value("NCHW4_NCHW", Correlation::Format::NCHW4_NCHW)
  538. .value("NHWC_NCHW", Correlation::Format::NHWC_NCHW)
  539. .value("NHWC_NCHW4_IC_SMALL", Correlation::Format::NHWC_NCHW4_IC_SMALL)
  540. .value("NCHW_NCHW4_IC_SMALL", Correlation::Format::NCHW_NCHW4_IC_SMALL)
  541. .value("CHWN4", Correlation::Format::CHWN4)
  542. .value("NCHW4_NHWC", Correlation::Format::NCHW4_NHWC)
  // Constructor from std::string: the input is passed through normalize_enum
  // (defined elsewhere; NOTE(review): presumably canonicalizes the spelling —
  // confirm) and matched against the enumerator names in declaration order.
  // A non-matching string raises py::cast_error carrying the original input.
  543. .def(py::init([](const std::string& in) {
  544. auto&& str = normalize_enum(in);
  545. if (str == "NCHW") return Correlation::Format::NCHW;
  546. if (str == "NHWC") return Correlation::Format::NHWC;
  547. if (str == "NHWCD4") return Correlation::Format::NHWCD4;
  548. if (str == "NCHW4") return Correlation::Format::NCHW4;
  549. if (str == "NCHW8") return Correlation::Format::NCHW8;
  550. if (str == "NCHW32") return Correlation::Format::NCHW32;
  551. if (str == "NCHW88") return Correlation::Format::NCHW88;
  552. if (str == "NCHW44") return Correlation::Format::NCHW44;
  553. if (str == "NCHW44_DOT") return Correlation::Format::NCHW44_DOT;
  554. if (str == "NCHW_WINOGRAD") return Correlation::Format::NCHW_WINOGRAD;
  555. if (str == "NCHW88_WINOGRAD") return Correlation::Format::NCHW88_WINOGRAD;
  556. if (str == "NCHW44_WINOGRAD") return Correlation::Format::NCHW44_WINOGRAD;
  557. if (str == "NCHW4_NCHW32") return Correlation::Format::NCHW4_NCHW32;
  558. if (str == "NCHW32_NCHW4") return Correlation::Format::NCHW32_NCHW4;
  559. if (str == "NCHW4_NCHW") return Correlation::Format::NCHW4_NCHW;
  560. if (str == "NHWC_NCHW") return Correlation::Format::NHWC_NCHW;
  561. if (str == "NHWC_NCHW4_IC_SMALL") return Correlation::Format::NHWC_NCHW4_IC_SMALL;
  562. if (str == "NCHW_NCHW4_IC_SMALL") return Correlation::Format::NCHW_NCHW4_IC_SMALL;
  563. if (str == "CHWN4") return Correlation::Format::CHWN4;
  564. if (str == "NCHW4_NHWC") return Correlation::Format::NCHW4_NHWC;
  565. throw py::cast_error("invalid enum value " + in);
  566. }));
  // Lets a Python str be converted implicitly wherever a Format is expected.
  567. py::implicitly_convertible<std::string, Correlation::Format>();
  568. CorrelationInst
  // Keyword constructor; every argument has a default, and no separate
  // py::init<>() overload is registered for this class.
  569. .def(py::init<::megdnn::param::Correlation::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, std::string>(), py::arg("format") = ::megdnn::param::Correlation::Format::NCHW, py::arg("kernel_size") = 1, py::arg("max_displacement") = 1, py::arg("stride1") = 1, py::arg("stride2") = 1, py::arg("pad_size") = 0, py::arg("is_multiply") = true, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  570. .def_readwrite("format", &Correlation::format)
  571. .def_readwrite("kernel_size", &Correlation::kernel_size)
  572. .def_readwrite("max_displacement", &Correlation::max_displacement)
  573. .def_readwrite("stride1", &Correlation::stride1)
  574. .def_readwrite("stride2", &Correlation::stride2)
  575. .def_readwrite("pad_size", &Correlation::pad_size)
  576. .def_readwrite("is_multiply", &Correlation::is_multiply);
  // Python binding for the Cumprod op: registers class "Cumprod" on module m.
  577. py::class_<Cumprod, std::shared_ptr<Cumprod>, OpDef> CumprodInst(m, "Cumprod");
  578. CumprodInst
  // Keyword constructor; every argument has a default. "axis" defaults to
  // 2147483647, i.e. 2^31 - 1, the largest int32_t value.
  // NOTE(review): that default looks like a "not set" sentinel — confirm.
  579. .def(py::init<int32_t, bool, bool, std::string>(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  580. .def_readwrite("axis", &Cumprod::axis)
  581. .def_readwrite("exclusive", &Cumprod::exclusive)
  582. .def_readwrite("reverse", &Cumprod::reverse);
  // Python binding for the Cumsum op: registers class "Cumsum" on module m.
  583. py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum");
  584. CumsumInst
  // Keyword constructor; every argument has a default. "axis" defaults to
  // 2147483647, i.e. 2^31 - 1, the largest int32_t value.
  // NOTE(review): that default looks like a "not set" sentinel — confirm.
  585. .def(py::init<int32_t, bool, bool, std::string>(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  586. .def_readwrite("axis", &Cumsum::axis)
  587. .def_readwrite("exclusive", &Cumsum::exclusive)
  588. .def_readwrite("reverse", &Cumsum::reverse);
  // Python binding for the CvtColor op: registers class "CvtColor" on module m
  // together with its nested "Mode" enum.
  589. py::class_<CvtColor, std::shared_ptr<CvtColor>, OpDef> CvtColorInst(m, "CvtColor");
  // Nested enum "Mode": one .value() entry per enumerator, exposed under the
  // same name as the C++ enumerator.
  590. py::enum_<CvtColor::Mode>(CvtColorInst, "Mode")
  591. .value("RGB2GRAY", CvtColor::Mode::RGB2GRAY)
  592. .value("RGB2YUV", CvtColor::Mode::RGB2YUV)
  593. .value("YUV2RGB", CvtColor::Mode::YUV2RGB)
  594. .value("GRAY2RGB", CvtColor::Mode::GRAY2RGB)
  595. .value("RGBA2RGB", CvtColor::Mode::RGBA2RGB)
  596. .value("RGBA2BGR", CvtColor::Mode::RGBA2BGR)
  597. .value("RGBA2GRAY", CvtColor::Mode::RGBA2GRAY)
  598. .value("RGB2BGR", CvtColor::Mode::RGB2BGR)
  599. .value("BGR2GRAY", CvtColor::Mode::BGR2GRAY)
  600. .value("BGR2RGB", CvtColor::Mode::BGR2RGB)
  601. .value("YUV2GRAY_NV21", CvtColor::Mode::YUV2GRAY_NV21)
  602. .value("YUV2RGB_NV21", CvtColor::Mode::YUV2RGB_NV21)
  603. .value("YUV2BGR_NV21", CvtColor::Mode::YUV2BGR_NV21)
  604. .value("YUV2GRAY_NV12", CvtColor::Mode::YUV2GRAY_NV12)
  605. .value("YUV2RGB_NV12", CvtColor::Mode::YUV2RGB_NV12)
  606. .value("YUV2BGR_NV12", CvtColor::Mode::YUV2BGR_NV12)
  607. .value("YUV2GRAY_YV12", CvtColor::Mode::YUV2GRAY_YV12)
  608. .value("YUV2RGB_YV12", CvtColor::Mode::YUV2RGB_YV12)
  609. .value("YUV2BGR_YV12", CvtColor::Mode::YUV2BGR_YV12)
  610. .value("YUV2GRAY_YU12", CvtColor::Mode::YUV2GRAY_YU12)
  611. .value("YUV2RGB_YU12", CvtColor::Mode::YUV2RGB_YU12)
  612. .value("YUV2BGR_YU12", CvtColor::Mode::YUV2BGR_YU12)
  613. .value("YCrCb2RGB", CvtColor::Mode::YCrCb2RGB)
  614. .value("YCrCb2BGR", CvtColor::Mode::YCrCb2BGR)
  615. .value("BT601_YUV2RGB_NV21", CvtColor::Mode::BT601_YUV2RGB_NV21)
  616. .value("BT601_YUV2BGR_NV21", CvtColor::Mode::BT601_YUV2BGR_NV21)
  617. .value("BT601_YUV2RGB_NV12", CvtColor::Mode::BT601_YUV2RGB_NV12)
  618. .value("BT601_YUV2BGR_NV12", CvtColor::Mode::BT601_YUV2BGR_NV12)
  619. .value("BT601_YUV2RGB_YV12", CvtColor::Mode::BT601_YUV2RGB_YV12)
  620. .value("BT601_YUV2BGR_YV12", CvtColor::Mode::BT601_YUV2BGR_YV12)
  621. .value("BT601_YUV2RGB_YU12", CvtColor::Mode::BT601_YUV2RGB_YU12)
  622. .value("BT601_YUV2BGR_YU12", CvtColor::Mode::BT601_YUV2BGR_YU12)
  // Constructor from std::string: the input is passed through normalize_enum
  // (defined elsewhere) and matched against the enumerator names in
  // declaration order. A non-matching string raises py::cast_error carrying
  // the original input.
  // NOTE(review): "YCrCb2RGB" / "YCrCb2BGR" contain lowercase letters; whether
  // a normalized string can still equal them depends on what normalize_enum
  // does with case — verify against its definition.
  623. .def(py::init([](const std::string& in) {
  624. auto&& str = normalize_enum(in);
  625. if (str == "RGB2GRAY") return CvtColor::Mode::RGB2GRAY;
  626. if (str == "RGB2YUV") return CvtColor::Mode::RGB2YUV;
  627. if (str == "YUV2RGB") return CvtColor::Mode::YUV2RGB;
  628. if (str == "GRAY2RGB") return CvtColor::Mode::GRAY2RGB;
  629. if (str == "RGBA2RGB") return CvtColor::Mode::RGBA2RGB;
  630. if (str == "RGBA2BGR") return CvtColor::Mode::RGBA2BGR;
  631. if (str == "RGBA2GRAY") return CvtColor::Mode::RGBA2GRAY;
  632. if (str == "RGB2BGR") return CvtColor::Mode::RGB2BGR;
  633. if (str == "BGR2GRAY") return CvtColor::Mode::BGR2GRAY;
  634. if (str == "BGR2RGB") return CvtColor::Mode::BGR2RGB;
  635. if (str == "YUV2GRAY_NV21") return CvtColor::Mode::YUV2GRAY_NV21;
  636. if (str == "YUV2RGB_NV21") return CvtColor::Mode::YUV2RGB_NV21;
  637. if (str == "YUV2BGR_NV21") return CvtColor::Mode::YUV2BGR_NV21;
  638. if (str == "YUV2GRAY_NV12") return CvtColor::Mode::YUV2GRAY_NV12;
  639. if (str == "YUV2RGB_NV12") return CvtColor::Mode::YUV2RGB_NV12;
  640. if (str == "YUV2BGR_NV12") return CvtColor::Mode::YUV2BGR_NV12;
  641. if (str == "YUV2GRAY_YV12") return CvtColor::Mode::YUV2GRAY_YV12;
  642. if (str == "YUV2RGB_YV12") return CvtColor::Mode::YUV2RGB_YV12;
  643. if (str == "YUV2BGR_YV12") return CvtColor::Mode::YUV2BGR_YV12;
  644. if (str == "YUV2GRAY_YU12") return CvtColor::Mode::YUV2GRAY_YU12;
  645. if (str == "YUV2RGB_YU12") return CvtColor::Mode::YUV2RGB_YU12;
  646. if (str == "YUV2BGR_YU12") return CvtColor::Mode::YUV2BGR_YU12;
  647. if (str == "YCrCb2RGB") return CvtColor::Mode::YCrCb2RGB;
  648. if (str == "YCrCb2BGR") return CvtColor::Mode::YCrCb2BGR;
  649. if (str == "BT601_YUV2RGB_NV21") return CvtColor::Mode::BT601_YUV2RGB_NV21;
  650. if (str == "BT601_YUV2BGR_NV21") return CvtColor::Mode::BT601_YUV2BGR_NV21;
  651. if (str == "BT601_YUV2RGB_NV12") return CvtColor::Mode::BT601_YUV2RGB_NV12;
  652. if (str == "BT601_YUV2BGR_NV12") return CvtColor::Mode::BT601_YUV2BGR_NV12;
  653. if (str == "BT601_YUV2RGB_YV12") return CvtColor::Mode::BT601_YUV2RGB_YV12;
  654. if (str == "BT601_YUV2BGR_YV12") return CvtColor::Mode::BT601_YUV2BGR_YV12;
  655. if (str == "BT601_YUV2RGB_YU12") return CvtColor::Mode::BT601_YUV2RGB_YU12;
  656. if (str == "BT601_YUV2BGR_YU12") return CvtColor::Mode::BT601_YUV2BGR_YU12;
  657. throw py::cast_error("invalid enum value " + in);
  658. }));
  // Lets a Python str be converted implicitly wherever a Mode is expected.
  659. py::implicitly_convertible<std::string, CvtColor::Mode>();
  660. CvtColorInst
  // Keyword constructor: "mode" defaults to Mode::RGB2GRAY, "scope" to {}.
  // No separate py::init<>() overload is registered for this class.
  661. .def(py::init<::megdnn::param::CvtColor::Mode, std::string>(), py::arg("mode") = ::megdnn::param::CvtColor::Mode::RGB2GRAY, py::arg("scope") = {})
  // Read/write attribute bound to the C++ member.
  662. .def_readwrite("mode", &CvtColor::mode);
  // Python binding for the DeformableConv op: registers class "DeformableConv"
  // on module m.
  663. py::class_<DeformableConv, std::shared_ptr<DeformableConv>, OpDef> DeformableConvInst(m, "DeformableConv");
  // Enum attributes are aliases of the objects already attached to
  // BatchConvBiasInst / AdaptivePoolingInst rather than fresh registrations.
  664. DeformableConvInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  665. DeformableConvInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  666. DeformableConvInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  667. DeformableConvInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  668. DeformableConvInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  669. DeformableConvInst
  // Keyword constructor; every argument has a default, and no separate
  // py::init<>() overload is registered for this class.
  // "strategy" defaults to the Strategy enumerator with underlying value 1;
  // "workspace_limit" defaults to 2^64 - 1 (the largest uint64_t value).
  670. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  671. .def_readwrite("mode", &DeformableConv::mode)
  672. .def_readwrite("pad_h", &DeformableConv::pad_h)
  673. .def_readwrite("pad_w", &DeformableConv::pad_w)
  674. .def_readwrite("stride_h", &DeformableConv::stride_h)
  675. .def_readwrite("stride_w", &DeformableConv::stride_w)
  676. .def_readwrite("dilate_h", &DeformableConv::dilate_h)
  677. .def_readwrite("dilate_w", &DeformableConv::dilate_w)
  678. .def_readwrite("sparse", &DeformableConv::sparse)
  679. .def_readwrite("format", &DeformableConv::format)
  680. .def_readwrite("compute_mode", &DeformableConv::compute_mode)
  681. .def_readwrite("strategy", &DeformableConv::strategy)
  682. .def_readwrite("workspace_limit", &DeformableConv::workspace_limit);
  // Python binding for the DeformablePSROIPooling op: registers class
  // "DeformablePSROIPooling" on module m.
  683. py::class_<DeformablePSROIPooling, std::shared_ptr<DeformablePSROIPooling>, OpDef> DeformablePSROIPoolingInst(m, "DeformablePSROIPooling");
  684. DeformablePSROIPoolingInst
  // Keyword constructor; every argument has a default ("no_trans" = true, the
  // numeric ones = 1, "scope" = {}). No separate py::init<>() overload is
  // registered for this class.
  685. .def(py::init<bool, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("no_trans") = true, py::arg("spatial_scale") = 1, py::arg("trans_std") = 1, py::arg("pooled_h") = 1, py::arg("pooled_w") = 1, py::arg("part_size") = 1, py::arg("sample_per_part") = 1, py::arg("scope") = {})
  // Read/write attributes bound to the C++ members.
  686. .def_readwrite("no_trans", &DeformablePSROIPooling::no_trans)
  687. .def_readwrite("spatial_scale", &DeformablePSROIPooling::spatial_scale)
  688. .def_readwrite("trans_std", &DeformablePSROIPooling::trans_std)
  689. .def_readwrite("pooled_h", &DeformablePSROIPooling::pooled_h)
  690. .def_readwrite("pooled_w", &DeformablePSROIPooling::pooled_w)
  691. .def_readwrite("part_size", &DeformablePSROIPooling::part_size)
  692. .def_readwrite("sample_per_part", &DeformablePSROIPooling::sample_per_part);
  // Python binding for the Diag op: registers class "Diag" on module m.
  693. py::class_<Diag, std::shared_ptr<Diag>, OpDef> DiagInst(m, "Diag");
  694. DiagInst
  // Keyword constructor: "k" defaults to 0, "scope" to {}.
  695. .def(py::init<int32_t, std::string>(), py::arg("k") = 0, py::arg("scope") = {})
  // Read/write attribute bound to the C++ member.
  696. .def_readwrite("k", &Diag::k);
  // Python binding for the Dimshuffle op: registers class "Dimshuffle" on module m.
  697. py::class_<Dimshuffle, std::shared_ptr<Dimshuffle>, OpDef> DimshuffleInst(m, "Dimshuffle");
  698. DimshuffleInst
  // Keyword constructor: "pattern" (std::vector<int32_t>) has no default,
  // "scope" defaults to {}.
  699. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("pattern"), py::arg("scope") = {})
  // Zero-argument constructor.
  700. .def(py::init<>())
  // Read/write attribute bound to the C++ member.
  701. .def_readwrite("pattern", &Dimshuffle::pattern);
  // Python binding for the Dot op: registers class "Dot" on module m.
  // Only a zero-argument constructor is exposed; no attributes are bound.
  702. py::class_<Dot, std::shared_ptr<Dot>, OpDef> DotInst(m, "Dot");
  703. DotInst
  704. .def(py::init<>());
  // Python binding for the Dropout op: registers class "Dropout" on module m.
  705. py::class_<Dropout, std::shared_ptr<Dropout>, OpDef> DropoutInst(m, "Dropout");
  706. DropoutInst
  // Keyword constructor: "drop_prob" and "seed" default to 0, "scope" to {};
  // "handle" (size_t) has no default.
  707. .def(py::init<float, uint64_t, size_t, std::string>(), py::arg("drop_prob") = 0, py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  // Zero-argument constructor.
  708. .def(py::init<>())
  // Read/write attributes bound to the C++ members.
  709. .def_readwrite("drop_prob", &Dropout::drop_prob)
  710. .def_readwrite("seed", &Dropout::seed)
  711. .def_readwrite("handle", &Dropout::handle);
  712. py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> ElemwiseInst(m, "Elemwise");
  713. py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
  714. .value("RELU", Elemwise::Mode::RELU)
  715. .value("ABS", Elemwise::Mode::ABS)
  716. .value("ACOS", Elemwise::Mode::ACOS)
  717. .value("ASIN", Elemwise::Mode::ASIN)
  718. .value("CEIL", Elemwise::Mode::CEIL)
  719. .value("COS", Elemwise::Mode::COS)
  720. .value("EXP", Elemwise::Mode::EXP)
  721. .value("EXPM1", Elemwise::Mode::EXPM1)
  722. .value("FLOOR", Elemwise::Mode::FLOOR)
  723. .value("LOG", Elemwise::Mode::LOG)
  724. .value("LOG1P", Elemwise::Mode::LOG1P)
  725. .value("NEGATE", Elemwise::Mode::NEGATE)
  726. .value("SIGMOID", Elemwise::Mode::SIGMOID)
  727. .value("SIN", Elemwise::Mode::SIN)
  728. .value("TANH", Elemwise::Mode::TANH)
  729. .value("ABS_GRAD", Elemwise::Mode::ABS_GRAD)
  730. .value("ADD", Elemwise::Mode::ADD)
  731. .value("FLOOR_DIV", Elemwise::Mode::FLOOR_DIV)
  732. .value("MAX", Elemwise::Mode::MAX)
  733. .value("MIN", Elemwise::Mode::MIN)
  734. .value("MOD", Elemwise::Mode::MOD)
  735. .value("MUL", Elemwise::Mode::MUL)
  736. .value("POW", Elemwise::Mode::POW)
  737. .value("SIGMOID_GRAD", Elemwise::Mode::SIGMOID_GRAD)
  738. .value("SUB", Elemwise::Mode::SUB)
  739. .value("SWITCH_GT0", Elemwise::Mode::SWITCH_GT0)
  740. .value("TANH_GRAD", Elemwise::Mode::TANH_GRAD)
  741. .value("TRUE_DIV", Elemwise::Mode::TRUE_DIV)
  742. .value("LOG_SUM_EXP", Elemwise::Mode::LOG_SUM_EXP)
  743. .value("LT", Elemwise::Mode::LT)
  744. .value("LEQ", Elemwise::Mode::LEQ)
  745. .value("EQ", Elemwise::Mode::EQ)
  746. .value("SHL", Elemwise::Mode::SHL)
  747. .value("SHR", Elemwise::Mode::SHR)
  748. .value("COND_LEQ_MOV", Elemwise::Mode::COND_LEQ_MOV)
  749. .value("FUSE_MUL_ADD3", Elemwise::Mode::FUSE_MUL_ADD3)
  750. .value("FUSE_MUL_ADD4", Elemwise::Mode::FUSE_MUL_ADD4)
  751. .value("FUSE_ADD_RELU", Elemwise::Mode::FUSE_ADD_RELU)
  752. .value("FUSE_ADD_SIGMOID", Elemwise::Mode::FUSE_ADD_SIGMOID)
  753. .value("FUSE_ADD_TANH", Elemwise::Mode::FUSE_ADD_TANH)
  754. .value("FAST_TANH", Elemwise::Mode::FAST_TANH)
  755. .value("FAST_TANH_GRAD", Elemwise::Mode::FAST_TANH_GRAD)
  756. .value("ROUND", Elemwise::Mode::ROUND)
  757. .value("RMULH", Elemwise::Mode::RMULH)
  758. .value("ATAN2", Elemwise::Mode::ATAN2)
  759. .value("ERF", Elemwise::Mode::ERF)
  760. .value("ERFINV", Elemwise::Mode::ERFINV)
  761. .value("ERFC", Elemwise::Mode::ERFC)
  762. .value("ERFCINV", Elemwise::Mode::ERFCINV)
  763. .value("H_SWISH", Elemwise::Mode::H_SWISH)
  764. .value("H_SWISH_GRAD", Elemwise::Mode::H_SWISH_GRAD)
  765. .value("FUSE_ADD_H_SWISH", Elemwise::Mode::FUSE_ADD_H_SWISH)
  766. .value("NOT", Elemwise::Mode::NOT)
  767. .value("AND", Elemwise::Mode::AND)
  768. .value("OR", Elemwise::Mode::OR)
  769. .value("XOR", Elemwise::Mode::XOR)
  770. .value("SILU", Elemwise::Mode::SILU)
  771. .value("SILU_GRAD", Elemwise::Mode::SILU_GRAD)
  772. .value("GELU", Elemwise::Mode::GELU)
  773. .value("GELU_GRAD", Elemwise::Mode::GELU_GRAD)
  774. .value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV)
  775. .value("SINH", Elemwise::Mode::SINH)
  776. .value("COSH", Elemwise::Mode::COSH)
  777. .value("ASINH", Elemwise::Mode::ASINH)
  778. .value("ACOSH", Elemwise::Mode::ACOSH)
  779. .value("ATANH", Elemwise::Mode::ATANH)
  780. .value("TAN", Elemwise::Mode::TAN)
  781. .value("ASINH_GRAD", Elemwise::Mode::ASINH_GRAD)
  782. .value("ACOSH_GRAD", Elemwise::Mode::ACOSH_GRAD)
  783. .value("ATANH_GRAD", Elemwise::Mode::ATANH_GRAD)
  784. .value("PRELU", Elemwise::Mode::PRELU)
  785. .value("CLIP", Elemwise::Mode::CLIP)
  786. .value("PRELU_GRAD", Elemwise::Mode::PRELU_GRAD)
  787. .value("SOFTPLUS", Elemwise::Mode::SOFTPLUS)
  788. .value("SOFTPLUS_GRAD", Elemwise::Mode::SOFTPLUS_GRAD)
  789. .value("RELU6", Elemwise::Mode::RELU6)
  790. .value("RELU6_GRAD", Elemwise::Mode::RELU6_GRAD)
  791. .value("HSIGMOID", Elemwise::Mode::HSIGMOID)
  792. .value("HSIGMOID_GRAD", Elemwise::Mode::HSIGMOID_GRAD)
  793. .value("LOGSIGMOID", Elemwise::Mode::LOGSIGMOID)
  794. .value("SQRT", Elemwise::Mode::SQRT)
  795. .value("SQUARE", Elemwise::Mode::SQUARE)
  796. .value("SIGN", Elemwise::Mode::SIGN)
  797. .value("SAFE_DIV", Elemwise::Mode::SAFE_DIV)
  798. .value("NEQ", Elemwise::Mode::NEQ)
  799. .value("ISNAN", Elemwise::Mode::ISNAN)
  800. .value("ISINF", Elemwise::Mode::ISINF)
  801. .def(py::init([](const std::string& in) {
  802. auto&& str = normalize_enum(in);
  803. if (str == "RELU") return Elemwise::Mode::RELU;
  804. if (str == "ABS") return Elemwise::Mode::ABS;
  805. if (str == "ACOS") return Elemwise::Mode::ACOS;
  806. if (str == "ASIN") return Elemwise::Mode::ASIN;
  807. if (str == "CEIL") return Elemwise::Mode::CEIL;
  808. if (str == "COS") return Elemwise::Mode::COS;
  809. if (str == "EXP") return Elemwise::Mode::EXP;
  810. if (str == "EXPM1") return Elemwise::Mode::EXPM1;
  811. if (str == "FLOOR") return Elemwise::Mode::FLOOR;
  812. if (str == "LOG") return Elemwise::Mode::LOG;
  813. if (str == "LOG1P") return Elemwise::Mode::LOG1P;
  814. if (str == "NEGATE") return Elemwise::Mode::NEGATE;
  815. if (str == "SIGMOID") return Elemwise::Mode::SIGMOID;
  816. if (str == "SIN") return Elemwise::Mode::SIN;
  817. if (str == "TANH") return Elemwise::Mode::TANH;
  818. if (str == "ABS_GRAD") return Elemwise::Mode::ABS_GRAD;
  819. if (str == "ADD") return Elemwise::Mode::ADD;
  820. if (str == "FLOOR_DIV") return Elemwise::Mode::FLOOR_DIV;
  821. if (str == "MAX") return Elemwise::Mode::MAX;
  822. if (str == "MIN") return Elemwise::Mode::MIN;
  823. if (str == "MOD") return Elemwise::Mode::MOD;
  824. if (str == "MUL") return Elemwise::Mode::MUL;
  825. if (str == "POW") return Elemwise::Mode::POW;
  826. if (str == "SIGMOID_GRAD") return Elemwise::Mode::SIGMOID_GRAD;
  827. if (str == "SUB") return Elemwise::Mode::SUB;
  828. if (str == "SWITCH_GT0") return Elemwise::Mode::SWITCH_GT0;
  829. if (str == "TANH_GRAD") return Elemwise::Mode::TANH_GRAD;
  830. if (str == "TRUE_DIV") return Elemwise::Mode::TRUE_DIV;
  831. if (str == "LOG_SUM_EXP") return Elemwise::Mode::LOG_SUM_EXP;
  832. if (str == "LT") return Elemwise::Mode::LT;
  833. if (str == "LEQ") return Elemwise::Mode::LEQ;
  834. if (str == "EQ") return Elemwise::Mode::EQ;
  835. if (str == "SHL") return Elemwise::Mode::SHL;
  836. if (str == "SHR") return Elemwise::Mode::SHR;
  837. if (str == "COND_LEQ_MOV") return Elemwise::Mode::COND_LEQ_MOV;
  838. if (str == "FUSE_MUL_ADD3") return Elemwise::Mode::FUSE_MUL_ADD3;
  839. if (str == "FUSE_MUL_ADD4") return Elemwise::Mode::FUSE_MUL_ADD4;
  840. if (str == "FUSE_ADD_RELU") return Elemwise::Mode::FUSE_ADD_RELU;
  841. if (str == "FUSE_ADD_SIGMOID") return Elemwise::Mode::FUSE_ADD_SIGMOID;
  842. if (str == "FUSE_ADD_TANH") return Elemwise::Mode::FUSE_ADD_TANH;
  843. if (str == "FAST_TANH") return Elemwise::Mode::FAST_TANH;
  844. if (str == "FAST_TANH_GRAD") return Elemwise::Mode::FAST_TANH_GRAD;
  845. if (str == "ROUND") return Elemwise::Mode::ROUND;
  846. if (str == "RMULH") return Elemwise::Mode::RMULH;
  847. if (str == "ATAN2") return Elemwise::Mode::ATAN2;
  848. if (str == "ERF") return Elemwise::Mode::ERF;
  849. if (str == "ERFINV") return Elemwise::Mode::ERFINV;
  850. if (str == "ERFC") return Elemwise::Mode::ERFC;
  851. if (str == "ERFCINV") return Elemwise::Mode::ERFCINV;
  852. if (str == "H_SWISH") return Elemwise::Mode::H_SWISH;
  853. if (str == "H_SWISH_GRAD") return Elemwise::Mode::H_SWISH_GRAD;
  854. if (str == "FUSE_ADD_H_SWISH") return Elemwise::Mode::FUSE_ADD_H_SWISH;
  855. if (str == "NOT") return Elemwise::Mode::NOT;
  856. if (str == "AND") return Elemwise::Mode::AND;
  857. if (str == "OR") return Elemwise::Mode::OR;
  858. if (str == "XOR") return Elemwise::Mode::XOR;
  859. if (str == "SILU") return Elemwise::Mode::SILU;
  860. if (str == "SILU_GRAD") return Elemwise::Mode::SILU_GRAD;
  861. if (str == "GELU") return Elemwise::Mode::GELU;
  862. if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD;
  863. if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV;
  864. if (str == "SINH") return Elemwise::Mode::SINH;
  865. if (str == "COSH") return Elemwise::Mode::COSH;
  866. if (str == "ASINH") return Elemwise::Mode::ASINH;
  867. if (str == "ACOSH") return Elemwise::Mode::ACOSH;
  868. if (str == "ATANH") return Elemwise::Mode::ATANH;
  869. if (str == "TAN") return Elemwise::Mode::TAN;
  870. if (str == "ASINH_GRAD") return Elemwise::Mode::ASINH_GRAD;
  871. if (str == "ACOSH_GRAD") return Elemwise::Mode::ACOSH_GRAD;
  872. if (str == "ATANH_GRAD") return Elemwise::Mode::ATANH_GRAD;
  873. if (str == "PRELU") return Elemwise::Mode::PRELU;
  874. if (str == "CLIP") return Elemwise::Mode::CLIP;
  875. if (str == "PRELU_GRAD") return Elemwise::Mode::PRELU_GRAD;
  876. if (str == "SOFTPLUS") return Elemwise::Mode::SOFTPLUS;
  877. if (str == "SOFTPLUS_GRAD") return Elemwise::Mode::SOFTPLUS_GRAD;
  878. if (str == "RELU6") return Elemwise::Mode::RELU6;
  879. if (str == "RELU6_GRAD") return Elemwise::Mode::RELU6_GRAD;
  880. if (str == "HSIGMOID") return Elemwise::Mode::HSIGMOID;
  881. if (str == "HSIGMOID_GRAD") return Elemwise::Mode::HSIGMOID_GRAD;
  882. if (str == "LOGSIGMOID") return Elemwise::Mode::LOGSIGMOID;
  883. if (str == "SQRT") return Elemwise::Mode::SQRT;
  884. if (str == "SQUARE") return Elemwise::Mode::SQUARE;
  885. if (str == "SIGN") return Elemwise::Mode::SIGN;
  886. if (str == "SAFE_DIV") return Elemwise::Mode::SAFE_DIV;
  887. if (str == "NEQ") return Elemwise::Mode::NEQ;
  888. if (str == "ISNAN") return Elemwise::Mode::ISNAN;
  889. if (str == "ISINF") return Elemwise::Mode::ISINF;
  890. throw py::cast_error("invalid enum value " + in);
  891. }));
  // Register an implicit conversion so a Python str can be passed where an
  // Elemwise::Mode is expected (relies on the string constructor defined on the enum above).
  892. py::implicitly_convertible<std::string, Elemwise::Mode>();
  // Elemwise op: constructor taking (mode, scope), with mode defaulting to RELU,
  // plus a read/write `mode` attribute.
  893. ElemwiseInst
  894. .def(py::init<::megdnn::param::Elemwise::Mode, std::string>(), py::arg("mode") = ::megdnn::param::Elemwise::Mode::RELU, py::arg("scope") = {})
  895. .def_readwrite("mode", &Elemwise::mode);
  // Python binding for the ElemwiseMultiType op: class registered in scope `m`,
  // derived from OpDef and held by std::shared_ptr.
  896. py::class_<ElemwiseMultiType, std::shared_ptr<ElemwiseMultiType>, OpDef> ElemwiseMultiTypeInst(m, "ElemwiseMultiType");
  // Nested enum ElemwiseMultiType.Mode; each Python member name equals the C++ enumerator name.
  897. py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
  898. .value("FUSE_MUL_ADD3_INT16x32x32x32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32)
  899. .value("FUSE_MUL_ADD3_IXxF32xF32xI8", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8)
  900. .value("ROUND_SHR_SATURATE_IXxI8xI8", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8)
  901. .value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8)
  902. .value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8)
  903. .value("ROUND_SHR_SATURATE_IXxI8xI16", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16)
  904. .value("QADD", ElemwiseMultiType::Mode::QADD)
  905. .value("QFUSE_ADD_RELU", ElemwiseMultiType::Mode::QFUSE_ADD_RELU)
  906. .value("QMUL", ElemwiseMultiType::Mode::QMUL)
  907. .value("QMIN", ElemwiseMultiType::Mode::QMIN)
  908. .value("QMAX", ElemwiseMultiType::Mode::QMAX)
  909. .value("QSUB", ElemwiseMultiType::Mode::QSUB)
  910. .value("QTRUE_DIV", ElemwiseMultiType::Mode::QTRUE_DIV)
  911. .value("QFUSE_ADD_SIGMOID", ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID)
  912. .value("QFUSE_ADD_TANH", ElemwiseMultiType::Mode::QFUSE_ADD_TANH)
  913. .value("QRELU", ElemwiseMultiType::Mode::QRELU)
  914. .value("QABS", ElemwiseMultiType::Mode::QABS)
  915. .value("QSIGMOID", ElemwiseMultiType::Mode::QSIGMOID)
  916. .value("QEXP", ElemwiseMultiType::Mode::QEXP)
  917. .value("QTANH", ElemwiseMultiType::Mode::QTANH)
  918. .value("QFUSE_MUL_ADD3", ElemwiseMultiType::Mode::QFUSE_MUL_ADD3)
  919. .value("QFAST_TANH", ElemwiseMultiType::Mode::QFAST_TANH)
  920. .value("QNEGATE", ElemwiseMultiType::Mode::QNEGATE)
  921. .value("QACOS", ElemwiseMultiType::Mode::QACOS)
  922. .value("QASIN", ElemwiseMultiType::Mode::QASIN)
  923. .value("QCEIL", ElemwiseMultiType::Mode::QCEIL)
  924. .value("QCOS", ElemwiseMultiType::Mode::QCOS)
  925. .value("QEXPM1", ElemwiseMultiType::Mode::QEXPM1)
  926. .value("QFLOOR", ElemwiseMultiType::Mode::QFLOOR)
  927. .value("QLOG", ElemwiseMultiType::Mode::QLOG)
  928. .value("QLOG1P", ElemwiseMultiType::Mode::QLOG1P)
  929. .value("QSIN", ElemwiseMultiType::Mode::QSIN)
  930. .value("QROUND", ElemwiseMultiType::Mode::QROUND)
  931. .value("QERF", ElemwiseMultiType::Mode::QERF)
  932. .value("QERFINV", ElemwiseMultiType::Mode::QERFINV)
  933. .value("QERFC", ElemwiseMultiType::Mode::QERFC)
  934. .value("QERFCINV", ElemwiseMultiType::Mode::QERFCINV)
  935. .value("QABS_GRAD", ElemwiseMultiType::Mode::QABS_GRAD)
  936. .value("QFLOOR_DIV", ElemwiseMultiType::Mode::QFLOOR_DIV)
  937. .value("QMOD", ElemwiseMultiType::Mode::QMOD)
  938. .value("QSIGMOID_GRAD", ElemwiseMultiType::Mode::QSIGMOID_GRAD)
  939. .value("QSWITCH_GT0", ElemwiseMultiType::Mode::QSWITCH_GT0)
  940. .value("QTANH_GRAD", ElemwiseMultiType::Mode::QTANH_GRAD)
  941. .value("QLT", ElemwiseMultiType::Mode::QLT)
  942. .value("QLEQ", ElemwiseMultiType::Mode::QLEQ)
  943. .value("QEQ", ElemwiseMultiType::Mode::QEQ)
  944. .value("QPOW", ElemwiseMultiType::Mode::QPOW)
  945. .value("QLOG_SUM_EXP", ElemwiseMultiType::Mode::QLOG_SUM_EXP)
  946. .value("QFAST_TANH_GRAD", ElemwiseMultiType::Mode::QFAST_TANH_GRAD)
  947. .value("QATAN2", ElemwiseMultiType::Mode::QATAN2)
  948. .value("QCOND_LEQ_MOV", ElemwiseMultiType::Mode::QCOND_LEQ_MOV)
  949. .value("QH_SWISH", ElemwiseMultiType::Mode::QH_SWISH)
  950. .value("QFUSE_ADD_H_SWISH", ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH)
  951. .value("QH_SWISH_GRAD", ElemwiseMultiType::Mode::QH_SWISH_GRAD)
  952. .value("FUSE_MUL_ADD3_INT16xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32)
  953. .value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32)
  954. .value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32)
  955. .value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV)
  956. .value("EQ", ElemwiseMultiType::Mode::EQ)
  957. .value("NEQ", ElemwiseMultiType::Mode::NEQ)
  958. .value("LT", ElemwiseMultiType::Mode::LT)
  959. .value("LEQ", ElemwiseMultiType::Mode::LEQ)
  960. .value("ISNAN", ElemwiseMultiType::Mode::ISNAN)
  961. .value("ISINF", ElemwiseMultiType::Mode::ISINF)
  // Construct a Mode from a string. The input goes through normalize_enum (defined
  // outside this chunk; NOTE(review): presumably canonicalizes the spelling - confirm)
  // and is then compared against the member names above, in declaration order.
  962. .def(py::init([](const std::string& in) {
  963. auto&& str = normalize_enum(in);
  964. if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
  965. if (str == "FUSE_MUL_ADD3_IXxF32xF32xI8") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8;
  966. if (str == "ROUND_SHR_SATURATE_IXxI8xI8") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8;
  967. if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8;
  968. if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8;
  969. if (str == "ROUND_SHR_SATURATE_IXxI8xI16") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16;
  970. if (str == "QADD") return ElemwiseMultiType::Mode::QADD;
  971. if (str == "QFUSE_ADD_RELU") return ElemwiseMultiType::Mode::QFUSE_ADD_RELU;
  972. if (str == "QMUL") return ElemwiseMultiType::Mode::QMUL;
  973. if (str == "QMIN") return ElemwiseMultiType::Mode::QMIN;
  974. if (str == "QMAX") return ElemwiseMultiType::Mode::QMAX;
  975. if (str == "QSUB") return ElemwiseMultiType::Mode::QSUB;
  976. if (str == "QTRUE_DIV") return ElemwiseMultiType::Mode::QTRUE_DIV;
  977. if (str == "QFUSE_ADD_SIGMOID") return ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID;
  978. if (str == "QFUSE_ADD_TANH") return ElemwiseMultiType::Mode::QFUSE_ADD_TANH;
  979. if (str == "QRELU") return ElemwiseMultiType::Mode::QRELU;
  980. if (str == "QABS") return ElemwiseMultiType::Mode::QABS;
  981. if (str == "QSIGMOID") return ElemwiseMultiType::Mode::QSIGMOID;
  982. if (str == "QEXP") return ElemwiseMultiType::Mode::QEXP;
  983. if (str == "QTANH") return ElemwiseMultiType::Mode::QTANH;
  984. if (str == "QFUSE_MUL_ADD3") return ElemwiseMultiType::Mode::QFUSE_MUL_ADD3;
  985. if (str == "QFAST_TANH") return ElemwiseMultiType::Mode::QFAST_TANH;
  986. if (str == "QNEGATE") return ElemwiseMultiType::Mode::QNEGATE;
  987. if (str == "QACOS") return ElemwiseMultiType::Mode::QACOS;
  988. if (str == "QASIN") return ElemwiseMultiType::Mode::QASIN;
  989. if (str == "QCEIL") return ElemwiseMultiType::Mode::QCEIL;
  990. if (str == "QCOS") return ElemwiseMultiType::Mode::QCOS;
  991. if (str == "QEXPM1") return ElemwiseMultiType::Mode::QEXPM1;
  992. if (str == "QFLOOR") return ElemwiseMultiType::Mode::QFLOOR;
  993. if (str == "QLOG") return ElemwiseMultiType::Mode::QLOG;
  994. if (str == "QLOG1P") return ElemwiseMultiType::Mode::QLOG1P;
  995. if (str == "QSIN") return ElemwiseMultiType::Mode::QSIN;
  996. if (str == "QROUND") return ElemwiseMultiType::Mode::QROUND;
  997. if (str == "QERF") return ElemwiseMultiType::Mode::QERF;
  998. if (str == "QERFINV") return ElemwiseMultiType::Mode::QERFINV;
  999. if (str == "QERFC") return ElemwiseMultiType::Mode::QERFC;
  1000. if (str == "QERFCINV") return ElemwiseMultiType::Mode::QERFCINV;
  1001. if (str == "QABS_GRAD") return ElemwiseMultiType::Mode::QABS_GRAD;
  1002. if (str == "QFLOOR_DIV") return ElemwiseMultiType::Mode::QFLOOR_DIV;
  1003. if (str == "QMOD") return ElemwiseMultiType::Mode::QMOD;
  1004. if (str == "QSIGMOID_GRAD") return ElemwiseMultiType::Mode::QSIGMOID_GRAD;
  1005. if (str == "QSWITCH_GT0") return ElemwiseMultiType::Mode::QSWITCH_GT0;
  1006. if (str == "QTANH_GRAD") return ElemwiseMultiType::Mode::QTANH_GRAD;
  1007. if (str == "QLT") return ElemwiseMultiType::Mode::QLT;
  1008. if (str == "QLEQ") return ElemwiseMultiType::Mode::QLEQ;
  1009. if (str == "QEQ") return ElemwiseMultiType::Mode::QEQ;
  1010. if (str == "QPOW") return ElemwiseMultiType::Mode::QPOW;
  1011. if (str == "QLOG_SUM_EXP") return ElemwiseMultiType::Mode::QLOG_SUM_EXP;
  1012. if (str == "QFAST_TANH_GRAD") return ElemwiseMultiType::Mode::QFAST_TANH_GRAD;
  1013. if (str == "QATAN2") return ElemwiseMultiType::Mode::QATAN2;
  1014. if (str == "QCOND_LEQ_MOV") return ElemwiseMultiType::Mode::QCOND_LEQ_MOV;
  1015. if (str == "QH_SWISH") return ElemwiseMultiType::Mode::QH_SWISH;
  1016. if (str == "QFUSE_ADD_H_SWISH") return ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH;
  1017. if (str == "QH_SWISH_GRAD") return ElemwiseMultiType::Mode::QH_SWISH_GRAD;
  1018. if (str == "FUSE_MUL_ADD3_INT16xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32;
  1019. if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32;
  1020. if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32;
  1021. if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV;
  1022. if (str == "EQ") return ElemwiseMultiType::Mode::EQ;
  1023. if (str == "NEQ") return ElemwiseMultiType::Mode::NEQ;
  1024. if (str == "LT") return ElemwiseMultiType::Mode::LT;
  1025. if (str == "LEQ") return ElemwiseMultiType::Mode::LEQ;
  1026. if (str == "ISNAN") return ElemwiseMultiType::Mode::ISNAN;
  1027. if (str == "ISINF") return ElemwiseMultiType::Mode::ISINF;
  // No member matched: report the caller's original (un-normalized) string.
  1028. throw py::cast_error("invalid enum value " + in);
  1029. }));
  // Let a Python str be passed wherever an ElemwiseMultiType::Mode is expected.
  1030. py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>();
  // Constructors and attributes: `dtype` has no default in the keyword constructor,
  // so a separate no-argument constructor is also registered.
  1031. ElemwiseMultiTypeInst
  1032. .def(py::init<::megdnn::param::ElemwiseMultiType::Mode, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32, py::arg("dtype"), py::arg("scope") = {})
  1033. .def(py::init<>())
  1034. .def_readwrite("mode", &ElemwiseMultiType::mode)
  1035. .def_readwrite("dtype", &ElemwiseMultiType::dtype);
  // Python binding for the ExternOpr op (base OpDef, held by std::shared_ptr).
  // The keyword constructor requires every field (only `scope` has a default);
  // a no-argument constructor is registered as well. All fields are read/write.
  1036. py::class_<ExternOpr, std::shared_ptr<ExternOpr>, OpDef> ExternOprInst(m, "ExternOpr");
  1037. ExternOprInst
  1038. .def(py::init<std::vector<std::vector<size_t>>, std::string, std::string, size_t, std::vector<::megdnn::DType>, std::string>(), py::arg("output_shapes"), py::arg("name"), py::arg("data"), py::arg("data_len"), py::arg("output_dtypes"), py::arg("scope") = {})
  1039. .def(py::init<>())
  1040. .def_readwrite("output_shapes", &ExternOpr::output_shapes)
  1041. .def_readwrite("name", &ExternOpr::name)
  1042. .def_readwrite("data", &ExternOpr::data)
  1043. .def_readwrite("data_len", &ExternOpr::data_len)
  1044. .def_readwrite("output_dtypes", &ExternOpr::output_dtypes);
  // Python binding for the Eye op. Keyword constructor defaults: k = 0,
  // dtype = Float32; `comp_node` has no default. A no-argument constructor is
  // also registered. Fields k, dtype and comp_node are read/write.
  1045. py::class_<Eye, std::shared_ptr<Eye>, OpDef> EyeInst(m, "Eye");
  1046. EyeInst
  1047. .def(py::init<int32_t, ::megdnn::DType, ::mgb::CompNode, std::string>(), py::arg("k") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("comp_node"), py::arg("scope") = {})
  1048. .def(py::init<>())
  1049. .def_readwrite("k", &Eye::k)
  1050. .def_readwrite("dtype", &Eye::dtype)
  1051. .def_readwrite("comp_node", &Eye::comp_node);
  // Python binding for the FakeQuant op. Every constructor argument has a default;
  // qmin/qmax default to the minimum and maximum int32 values. Both are read/write.
  1052. py::class_<FakeQuant, std::shared_ptr<FakeQuant>, OpDef> FakeQuantInst(m, "FakeQuant");
  1053. FakeQuantInst
  1054. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1055. .def_readwrite("qmin", &FakeQuant::qmin)
  1056. .def_readwrite("qmax", &FakeQuant::qmax);
  // Python binding for the FastpathCopy op: only a no-argument constructor, no attributes.
  1057. py::class_<FastpathCopy, std::shared_ptr<FastpathCopy>, OpDef> FastpathCopyInst(m, "FastpathCopy");
  1058. FastpathCopyInst
  1059. .def(py::init<>());
  // Python binding for the GammaRNG op. Keyword constructor: seed defaults to 0,
  // `handle` has no default; a no-argument constructor is also registered.
  // Fields seed and handle are read/write.
  1060. py::class_<GammaRNG, std::shared_ptr<GammaRNG>, OpDef> GammaRNGInst(m, "GammaRNG");
  1061. GammaRNGInst
  1062. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1063. .def(py::init<>())
  1064. .def_readwrite("seed", &GammaRNG::seed)
  1065. .def_readwrite("handle", &GammaRNG::handle);
  // Python binding for the GaussianRNG op. Keyword constructor defaults: seed = 0,
  // mean = 0, std = 1, dtype = Float32; `handle` has no default. A no-argument
  // constructor is also registered. All five fields are read/write.
  1066. py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef> GaussianRNGInst(m, "GaussianRNG");
  1067. GaussianRNGInst
  1068. .def(py::init<uint64_t, float, float, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("mean") = 0, py::arg("std") = 1, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
  1069. .def(py::init<>())
  1070. .def_readwrite("seed", &GaussianRNG::seed)
  1071. .def_readwrite("mean", &GaussianRNG::mean)
  1072. .def_readwrite("std", &GaussianRNG::std)
  1073. .def_readwrite("dtype", &GaussianRNG::dtype)
  1074. .def_readwrite("handle", &GaussianRNG::handle);
  // Python binding for the GetVarShape op. `axis` defaults to
  // OptionalAxisV1::INVALID_AXIS and is exposed as a read/write attribute.
  1075. py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef> GetVarShapeInst(m, "GetVarShape");
  1076. GetVarShapeInst
  1077. .def(py::init<int32_t, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("scope") = {})
  1078. .def_readwrite("axis", &GetVarShape::axis);
  // Python binding for the GroupLocal op.
  1079. py::class_<GroupLocal, std::shared_ptr<GroupLocal>, OpDef> GroupLocalInst(m, "GroupLocal");
  // The nested enum attributes are not re-declared here: they alias the enum objects
  // already attached to BatchConvBias (Mode, Sparse, ComputeMode) and AdaptivePooling (Format).
  1080. GroupLocalInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
  1081. GroupLocalInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
  1082. GroupLocalInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1083. GroupLocalInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
  // Keyword constructor in which every argument has a default (padding 0, stride 1,
  // dilation 1, CROSS_CORRELATION / DENSE / NCHW / DEFAULT), then read/write fields.
  1084. GroupLocalInst
  1085. .def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
  1086. .def_readwrite("mode", &GroupLocal::mode)
  1087. .def_readwrite("pad_h", &GroupLocal::pad_h)
  1088. .def_readwrite("pad_w", &GroupLocal::pad_w)
  1089. .def_readwrite("stride_h", &GroupLocal::stride_h)
  1090. .def_readwrite("stride_w", &GroupLocal::stride_w)
  1091. .def_readwrite("dilate_h", &GroupLocal::dilate_h)
  1092. .def_readwrite("dilate_w", &GroupLocal::dilate_w)
  1093. .def_readwrite("sparse", &GroupLocal::sparse)
  1094. .def_readwrite("format", &GroupLocal::format)
  1095. .def_readwrite("compute_mode", &GroupLocal::compute_mode);
  // Python binding for the Identity op: only a no-argument constructor, no attributes.
  1096. py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity");
  1097. IdentityInst
  1098. .def(py::init<>());
  // Python binding for the Images2Neibs op. Every constructor argument has a default:
  // pad 0, stride 1, dilate 1, window 3 (each for both h and w). All fields are read/write.
  1099. py::class_<Images2Neibs, std::shared_ptr<Images2Neibs>, OpDef> Images2NeibsInst(m, "Images2Neibs");
  1100. Images2NeibsInst
  1101. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
  1102. .def_readwrite("pad_h", &Images2Neibs::pad_h)
  1103. .def_readwrite("pad_w", &Images2Neibs::pad_w)
  1104. .def_readwrite("stride_h", &Images2Neibs::stride_h)
  1105. .def_readwrite("stride_w", &Images2Neibs::stride_w)
  1106. .def_readwrite("dilate_h", &Images2Neibs::dilate_h)
  1107. .def_readwrite("dilate_w", &Images2Neibs::dilate_w)
  1108. .def_readwrite("window_h", &Images2Neibs::window_h)
  1109. .def_readwrite("window_w", &Images2Neibs::window_w);
  // Python binding for the IncrMeshIndexing op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples and has no default, so a no-argument
  // constructor is registered too. NOTE(review): the meaning of the tuple fields
  // is not visible in this chunk.
  1110. py::class_<IncrMeshIndexing, std::shared_ptr<IncrMeshIndexing>, OpDef> IncrMeshIndexingInst(m, "IncrMeshIndexing");
  1111. IncrMeshIndexingInst
  1112. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1113. .def(py::init<>())
  1114. .def_readwrite("items", &IncrMeshIndexing::items);
  // Python binding for the IncrSubtensor op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples with no default; a no-argument
  // constructor is also registered. `items` is read/write.
  1115. py::class_<IncrSubtensor, std::shared_ptr<IncrSubtensor>, OpDef> IncrSubtensorInst(m, "IncrSubtensor");
  1116. IncrSubtensorInst
  1117. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1118. .def(py::init<>())
  1119. .def_readwrite("items", &IncrSubtensor::items);
  // Python binding for the IndexingIncrMultiAxisVec op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples with no default; a no-argument
  // constructor is also registered. `items` is read/write.
  1120. py::class_<IndexingIncrMultiAxisVec, std::shared_ptr<IndexingIncrMultiAxisVec>, OpDef> IndexingIncrMultiAxisVecInst(m, "IndexingIncrMultiAxisVec");
  1121. IndexingIncrMultiAxisVecInst
  1122. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1123. .def(py::init<>())
  1124. .def_readwrite("items", &IndexingIncrMultiAxisVec::items);
  // Python binding for the IndexingMultiAxisVec op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples with no default; a no-argument
  // constructor is also registered. `items` is read/write.
  1125. py::class_<IndexingMultiAxisVec, std::shared_ptr<IndexingMultiAxisVec>, OpDef> IndexingMultiAxisVecInst(m, "IndexingMultiAxisVec");
  1126. IndexingMultiAxisVecInst
  1127. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1128. .def(py::init<>())
  1129. .def_readwrite("items", &IndexingMultiAxisVec::items);
  // Python binding for the IndexingOneHot op. Keyword constructor: axis defaults
  // to 0, `ndim` has no default; a no-argument constructor is also registered.
  // Fields axis and ndim are read/write.
  1130. py::class_<IndexingOneHot, std::shared_ptr<IndexingOneHot>, OpDef> IndexingOneHotInst(m, "IndexingOneHot");
  1131. IndexingOneHotInst
  1132. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
  1133. .def(py::init<>())
  1134. .def_readwrite("axis", &IndexingOneHot::axis)
  1135. .def_readwrite("ndim", &IndexingOneHot::ndim);
  // Python binding for the IndexingSetMultiAxisVec op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples with no default; a no-argument
  // constructor is also registered. `items` is read/write.
  1136. py::class_<IndexingSetMultiAxisVec, std::shared_ptr<IndexingSetMultiAxisVec>, OpDef> IndexingSetMultiAxisVecInst(m, "IndexingSetMultiAxisVec");
  1137. IndexingSetMultiAxisVecInst
  1138. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1139. .def(py::init<>())
  1140. .def_readwrite("items", &IndexingSetMultiAxisVec::items);
  // Python binding for the IndexingSetOneHot op. Keyword constructor: axis defaults
  // to 0, `ndim` has no default; a no-argument constructor is also registered.
  // Fields axis and ndim are read/write.
  1141. py::class_<IndexingSetOneHot, std::shared_ptr<IndexingSetOneHot>, OpDef> IndexingSetOneHotInst(m, "IndexingSetOneHot");
  1142. IndexingSetOneHotInst
  1143. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
  1144. .def(py::init<>())
  1145. .def_readwrite("axis", &IndexingSetOneHot::axis)
  1146. .def_readwrite("ndim", &IndexingSetOneHot::ndim);
  // Python binding for the InplaceAdd op: only a no-argument constructor, no attributes.
  1147. py::class_<InplaceAdd, std::shared_ptr<InplaceAdd>, OpDef> InplaceAddInst(m, "InplaceAdd");
  1148. InplaceAddInst
  1149. .def(py::init<>());
  // Python binding for the LAMBUpdate op. Every constructor argument has a default:
  // the six float parameters default to 1.f, bias_correction to true and
  // always_adapt to false. All fields are read/write.
  1150. py::class_<LAMBUpdate, std::shared_ptr<LAMBUpdate>, OpDef> LAMBUpdateInst(m, "LAMBUpdate");
  1151. LAMBUpdateInst
  1152. .def(py::init<float, float, float, float, float, float, bool, bool, std::string>(), py::arg("beta_1") = 1.f, py::arg("beta_2") = 1.f, py::arg("step") = 1.f, py::arg("lr") = 1.f, py::arg("weight_decay") = 1.f, py::arg("eps") = 1.f, py::arg("bias_correction") = true, py::arg("always_adapt") = false, py::arg("scope") = {})
  1153. .def_readwrite("beta_1", &LAMBUpdate::beta_1)
  1154. .def_readwrite("beta_2", &LAMBUpdate::beta_2)
  1155. .def_readwrite("step", &LAMBUpdate::step)
  1156. .def_readwrite("lr", &LAMBUpdate::lr)
  1157. .def_readwrite("weight_decay", &LAMBUpdate::weight_decay)
  1158. .def_readwrite("eps", &LAMBUpdate::eps)
  1159. .def_readwrite("bias_correction", &LAMBUpdate::bias_correction)
  1160. .def_readwrite("always_adapt", &LAMBUpdate::always_adapt);
  // Python binding for the LRN op. Constructor defaults: n = 5, k = 2.f,
  // alpha = 1e-4f, beta = 0.75f. All four fields are read/write.
  1161. py::class_<LRN, std::shared_ptr<LRN>, OpDef> LRNInst(m, "LRN");
  1162. LRNInst
  1163. .def(py::init<uint32_t, float, float, float, std::string>(), py::arg("n") = 5, py::arg("k") = 2.f, py::arg("alpha") = 1e-4f, py::arg("beta") = 0.75f, py::arg("scope") = {})
  1164. .def_readwrite("n", &LRN::n)
  1165. .def_readwrite("k", &LRN::k)
  1166. .def_readwrite("alpha", &LRN::alpha)
  1167. .def_readwrite("beta", &LRN::beta);
  // Python binding for the LSQ op. Every constructor argument has a default;
  // qmin/qmax default to the minimum and maximum int32 values. Both are read/write.
  1168. py::class_<LSQ, std::shared_ptr<LSQ>, OpDef> LSQInst(m, "LSQ");
  1169. LSQInst
  1170. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1171. .def_readwrite("qmin", &LSQ::qmin)
  1172. .def_readwrite("qmax", &LSQ::qmax);
  // Python binding for the LSTM op.
  1173. py::class_<LSTM, std::shared_ptr<LSTM>, OpDef> LSTMInst(m, "LSTM");
  // LSTM.FwdMode aliases the enum object already attached to BatchNorm.
  1174. LSTMInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  // Every constructor argument has a default (num_layers 1, bidirectional false,
  // bias true, hidden_size 128, proj_size 0, dropout 0.f, fwd_mode TRAINING).
  1175. LSTMInst
  1176. .def(py::init<uint32_t, bool, bool, uint32_t, uint32_t, float, ::megdnn::param::LSTM::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("proj_size") = 0, py::arg("dropout") = 0.f, py::arg("fwd_mode") = ::megdnn::param::LSTM::FwdMode::TRAINING, py::arg("scope") = {})
  1177. .def_readwrite("num_layers", &LSTM::num_layers)
  1178. .def_readwrite("bidirectional", &LSTM::bidirectional)
  1179. .def_readwrite("bias", &LSTM::bias)
  1180. .def_readwrite("hidden_size", &LSTM::hidden_size)
  1181. .def_readwrite("proj_size", &LSTM::proj_size)
  1182. .def_readwrite("dropout", &LSTM::dropout)
  1183. .def_readwrite("fwd_mode", &LSTM::fwd_mode);
  // Python binding for the LSTMCell op: only a no-argument constructor, no attributes.
  1184. py::class_<LSTMCell, std::shared_ptr<LSTMCell>, OpDef> LSTMCellInst(m, "LSTMCell");
  1185. LSTMCellInst
  1186. .def(py::init<>());
  // Python binding for the LayerNorm op. Constructor defaults: affine = true,
  // eps = 1e-5f, normalized_dim = 1, normalized_size = 1. All fields are read/write.
  1187. py::class_<LayerNorm, std::shared_ptr<LayerNorm>, OpDef> LayerNormInst(m, "LayerNorm");
  1188. LayerNormInst
  1189. .def(py::init<bool, float, uint64_t, uint64_t, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("normalized_dim") = 1, py::arg("normalized_size") = 1, py::arg("scope") = {})
  1190. .def_readwrite("affine", &LayerNorm::affine)
  1191. .def_readwrite("eps", &LayerNorm::eps)
  1192. .def_readwrite("normalized_dim", &LayerNorm::normalized_dim)
  1193. .def_readwrite("normalized_size", &LayerNorm::normalized_size);
  // Python binding for the Linspace op. Keyword constructor: endpoint defaults to
  // true, `comp_node` has no default; a no-argument constructor is also registered.
  // Fields endpoint and comp_node are read/write.
  1194. py::class_<Linspace, std::shared_ptr<Linspace>, OpDef> LinspaceInst(m, "Linspace");
  1195. LinspaceInst
  1196. .def(py::init<bool, ::mgb::CompNode, std::string>(), py::arg("endpoint") = true, py::arg("comp_node"), py::arg("scope") = {})
  1197. .def(py::init<>())
  1198. .def_readwrite("endpoint", &Linspace::endpoint)
  1199. .def_readwrite("comp_node", &Linspace::comp_node);
  // Python binding for the MagicMindRuntime op. `buf` (string) and `buf_size` have
  // no defaults in the keyword constructor; a no-argument constructor is also
  // registered. Both fields are read/write.
  1200. py::class_<MagicMindRuntime, std::shared_ptr<MagicMindRuntime>, OpDef> MagicMindRuntimeInst(m, "MagicMindRuntime");
  1201. MagicMindRuntimeInst
  1202. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  1203. .def(py::init<>())
  1204. .def_readwrite("buf", &MagicMindRuntime::buf)
  1205. .def_readwrite("buf_size", &MagicMindRuntime::buf_size);
  // Python binding for the MatrixInverse op: only a no-argument constructor, no attributes.
  1206. py::class_<MatrixInverse, std::shared_ptr<MatrixInverse>, OpDef> MatrixInverseInst(m, "MatrixInverse");
  1207. MatrixInverseInst
  1208. .def(py::init<>());
  // Python binding for the MatrixMul op.
  1209. py::class_<MatrixMul, std::shared_ptr<MatrixMul>, OpDef> MatrixMulInst(m, "MatrixMul");
  // Nested enums alias objects already attached to BatchedMatrixMul (ComputeMode,
  // Format) and BatchConvBias (Strategy) rather than being declared again.
  1210. MatrixMulInst.attr("ComputeMode") = BatchedMatrixMulInst.attr("ComputeMode");
  1211. MatrixMulInst.attr("Format") = BatchedMatrixMulInst.attr("Format");
  1212. MatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  // Keyword constructor: `strategy` defaults to the Strategy with underlying value 1,
  // `workspace_limit` to the maximum uint64 value; `dimA` and `dimB` have no default,
  // so a no-argument constructor is registered as well.
  1213. MatrixMulInst
  1214. .def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
  1215. .def(py::init<>())
  1216. .def_readwrite("transposeA", &MatrixMul::transposeA)
  1217. .def_readwrite("transposeB", &MatrixMul::transposeB)
  1218. .def_readwrite("compute_mode", &MatrixMul::compute_mode)
  1219. .def_readwrite("format", &MatrixMul::format)
  1220. .def_readwrite("strategy", &MatrixMul::strategy)
  1221. .def_readwrite("workspace_limit", &MatrixMul::workspace_limit)
  1222. .def_readwrite("dimA", &MatrixMul::dimA)
  1223. .def_readwrite("dimB", &MatrixMul::dimB);
  // Python binding for the MeshIndexing op. `items` is a list of
  // (int8, bool, bool, bool, bool) tuples with no default; a no-argument
  // constructor is also registered. `items` is read/write.
  1224. py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
  1225. MeshIndexingInst
  1226. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1227. .def(py::init<>())
  1228. .def_readwrite("items", &MeshIndexing::items);
  // Python binding for the NMSKeep op. `iou_thresh` (float) and `max_output` (uint32)
  // have no defaults in the keyword constructor; a no-argument constructor is also
  // registered. Both fields are read/write.
  1229. py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep");
  1230. NMSKeepInst
  1231. .def(py::init<float, uint32_t, std::string>(), py::arg("iou_thresh"), py::arg("max_output"), py::arg("scope") = {})
  1232. .def(py::init<>())
  1233. .def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
  1234. .def_readwrite("max_output", &NMSKeep::max_output);
  // Python binding for the NvOf op. `precision` defaults to 1 and is read/write.
  1235. py::class_<NvOf, std::shared_ptr<NvOf>, OpDef> NvOfInst(m, "NvOf");
  1236. NvOfInst
  1237. .def(py::init<uint32_t, std::string>(), py::arg("precision") = 1, py::arg("scope") = {})
  1238. .def_readwrite("precision", &NvOf::precision);
  // Python binding for the Padding op: class registered in scope `m`, derived from OpDef.
  1239. py::class_<Padding, std::shared_ptr<Padding>, OpDef> PaddingInst(m, "Padding");
  // Nested enum Padding.PaddingMode with members REPLICATE, REFLECT and CONSTANT.
  1240. py::enum_<Padding::PaddingMode>(PaddingInst, "PaddingMode")
  1241. .value("REPLICATE", Padding::PaddingMode::REPLICATE)
  1242. .value("REFLECT", Padding::PaddingMode::REFLECT)
  1243. .value("CONSTANT", Padding::PaddingMode::CONSTANT)
  // Construct a PaddingMode from a string: the input goes through normalize_enum
  // (defined outside this chunk) and is compared against the member names; an
  // unmatched value raises py::cast_error carrying the original string.
  1244. .def(py::init([](const std::string& in) {
  1245. auto&& str = normalize_enum(in);
  1246. if (str == "REPLICATE") return Padding::PaddingMode::REPLICATE;
  1247. if (str == "REFLECT") return Padding::PaddingMode::REFLECT;
  1248. if (str == "CONSTANT") return Padding::PaddingMode::CONSTANT;
  1249. throw py::cast_error("invalid enum value " + in);
  1250. }));
  // Let a Python str be passed wherever a Padding::PaddingMode is expected.
  1251. py::implicitly_convertible<std::string, Padding::PaddingMode>();
  1252. PaddingInst
  1253. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float, ::megdnn::param::Padding::PaddingMode, std::string>(), py::arg("front_offset_dim0") = 0, py::arg("front_offset_dim1") = 0, py::arg("front_offset_dim2") = 0, py::arg("front_offset_dim3") = 0, py::arg("front_offset_dim4") = 0, py::arg("front_offset_dim5") = 0, py::arg("front_offset_dim6") = 0, py::arg("back_offset_dim0") = 0, py::arg("back_offset_dim1") = 0, py::arg("back_offset_dim2") = 0, py::arg("back_offset_dim3") = 0, py::arg("back_offset_dim4") = 0, py::arg("back_offset_dim5") = 0, py::arg("back_offset_dim6") = 0, py::arg("padding_val") = 0, py::arg("padding_mode") = ::megdnn::param::Padding::PaddingMode::CONSTANT, py::arg("scope") = {})
  1254. .def_readwrite("front_offset_dim0", &Padding::front_offset_dim0)
  1255. .def_readwrite("front_offset_dim1", &Padding::front_offset_dim1)
  1256. .def_readwrite("front_offset_dim2", &Padding::front_offset_dim2)
  1257. .def_readwrite("front_offset_dim3", &Padding::front_offset_dim3)
  1258. .def_readwrite("front_offset_dim4", &Padding::front_offset_dim4)
  1259. .def_readwrite("front_offset_dim5", &Padding::front_offset_dim5)
  1260. .def_readwrite("front_offset_dim6", &Padding::front_offset_dim6)
  1261. .def_readwrite("back_offset_dim0", &Padding::back_offset_dim0)
  1262. .def_readwrite("back_offset_dim1", &Padding::back_offset_dim1)
  1263. .def_readwrite("back_offset_dim2", &Padding::back_offset_dim2)
  1264. .def_readwrite("back_offset_dim3", &Padding::back_offset_dim3)
  1265. .def_readwrite("back_offset_dim4", &Padding::back_offset_dim4)
  1266. .def_readwrite("back_offset_dim5", &Padding::back_offset_dim5)
  1267. .def_readwrite("back_offset_dim6", &Padding::back_offset_dim6)
  1268. .def_readwrite("padding_val", &Padding::padding_val)
  1269. .def_readwrite("padding_mode", &Padding::padding_mode);
  1270. py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef> ParamPackConcatInst(m, "ParamPackConcat");
  1271. ParamPackConcatInst
  1272. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("offsets"), py::arg("scope") = {})
  1273. .def(py::init<>())
  1274. .def_readwrite("offsets", &ParamPackConcat::offsets);
  1275. py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef> ParamPackSplitInst(m, "ParamPackSplit");
  1276. ParamPackSplitInst
  1277. .def(py::init<std::vector<int32_t>, std::vector<std::vector<size_t>>, std::string>(), py::arg("offsets"), py::arg("shapes"), py::arg("scope") = {})
  1278. .def(py::init<>())
  1279. .def_readwrite("offsets", &ParamPackSplit::offsets)
  1280. .def_readwrite("shapes", &ParamPackSplit::shapes);
  1281. py::class_<PermutationRNG, std::shared_ptr<PermutationRNG>, OpDef> PermutationRNGInst(m, "PermutationRNG");
  1282. PermutationRNGInst
  1283. .def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32), py::arg("handle"), py::arg("scope") = {})
  1284. .def(py::init<>())
  1285. .def_readwrite("seed", &PermutationRNG::seed)
  1286. .def_readwrite("dtype", &PermutationRNG::dtype)
  1287. .def_readwrite("handle", &PermutationRNG::handle);
  1288. py::class_<PixelShuffle, std::shared_ptr<PixelShuffle>, OpDef> PixelShuffleInst(m, "PixelShuffle");
  1289. PixelShuffleInst
  1290. .def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
  1291. .def(py::init<>())
  1292. .def_readwrite("factor", &PixelShuffle::factor);
  1293. py::class_<PixelShuffleBackward, std::shared_ptr<PixelShuffleBackward>, OpDef> PixelShuffleBackwardInst(m, "PixelShuffleBackward");
  1294. PixelShuffleBackwardInst
  1295. .def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
  1296. .def(py::init<>())
  1297. .def_readwrite("factor", &PixelShuffleBackward::factor);
  1298. py::class_<PoissonRNG, std::shared_ptr<PoissonRNG>, OpDef> PoissonRNGInst(m, "PoissonRNG");
  1299. PoissonRNGInst
  1300. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1301. .def(py::init<>())
  1302. .def_readwrite("seed", &PoissonRNG::seed)
  1303. .def_readwrite("handle", &PoissonRNG::handle);
  1304. py::class_<Pooling, std::shared_ptr<Pooling>, OpDef> PoolingInst(m, "Pooling");
  1305. PoolingInst.attr("Mode") = AdaptivePoolingInst.attr("Mode");
  1306. PoolingInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1307. PoolingInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
  1308. PoolingInst
  1309. .def(py::init<::megdnn::param::Pooling::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Pooling::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Pooling::Mode::MAX, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 2, py::arg("stride_w") = 2, py::arg("window_h") = 2, py::arg("window_w") = 2, py::arg("format") = ::megdnn::param::Pooling::Format::NCHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
  1310. .def_readwrite("mode", &Pooling::mode)
  1311. .def_readwrite("pad_h", &Pooling::pad_h)
  1312. .def_readwrite("pad_w", &Pooling::pad_w)
  1313. .def_readwrite("stride_h", &Pooling::stride_h)
  1314. .def_readwrite("stride_w", &Pooling::stride_w)
  1315. .def_readwrite("window_h", &Pooling::window_h)
  1316. .def_readwrite("window_w", &Pooling::window_w)
  1317. .def_readwrite("format", &Pooling::format)
  1318. .def_readwrite("strategy", &Pooling::strategy)
  1319. .def_readwrite("workspace_limit", &Pooling::workspace_limit);
  1320. py::class_<RNN, std::shared_ptr<RNN>, OpDef> RNNInst(m, "RNN");
  1321. py::enum_<RNN::NonlineMode>(RNNInst, "NonlineMode")
  1322. .value("IDENTITY", RNN::NonlineMode::IDENTITY)
  1323. .value("RELU", RNN::NonlineMode::RELU)
  1324. .value("TANH", RNN::NonlineMode::TANH)
  1325. .def(py::init([](const std::string& in) {
  1326. auto&& str = normalize_enum(in);
  1327. if (str == "IDENTITY") return RNN::NonlineMode::IDENTITY;
  1328. if (str == "RELU") return RNN::NonlineMode::RELU;
  1329. if (str == "TANH") return RNN::NonlineMode::TANH;
  1330. throw py::cast_error("invalid enum value " + in);
  1331. }));
  1332. py::implicitly_convertible<std::string, RNN::NonlineMode>();
  1333. RNNInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
  1334. RNNInst
  1335. .def(py::init<uint32_t, bool, bool, uint32_t, float, ::megdnn::param::RNN::NonlineMode, ::megdnn::param::RNN::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("dropout") = 0.f, py::arg("nonlineMode") = ::megdnn::param::RNN::NonlineMode::IDENTITY, py::arg("fwd_mode") = ::megdnn::param::RNN::FwdMode::TRAINING, py::arg("scope") = {})
  1336. .def_readwrite("num_layers", &RNN::num_layers)
  1337. .def_readwrite("bidirectional", &RNN::bidirectional)
  1338. .def_readwrite("bias", &RNN::bias)
  1339. .def_readwrite("hidden_size", &RNN::hidden_size)
  1340. .def_readwrite("dropout", &RNN::dropout)
  1341. .def_readwrite("nonlineMode", &RNN::nonlineMode)
  1342. .def_readwrite("fwd_mode", &RNN::fwd_mode);
  1343. py::class_<RNNCell, std::shared_ptr<RNNCell>, OpDef> RNNCellInst(m, "RNNCell");
  1344. RNNCellInst.attr("NonlineMode") = RNNInst.attr("NonlineMode");
  1345. RNNCellInst
  1346. .def(py::init<::megdnn::param::RNNCell::NonlineMode, std::string>(), py::arg("nonlineMode") = ::megdnn::param::RNNCell::NonlineMode::IDENTITY, py::arg("scope") = {})
  1347. .def_readwrite("nonlineMode", &RNNCell::nonlineMode);
  1348. py::class_<ROIAlign, std::shared_ptr<ROIAlign>, OpDef> ROIAlignInst(m, "ROIAlign");
  1349. py::enum_<ROIAlign::Mode>(ROIAlignInst, "Mode")
  1350. .value("MAX", ROIAlign::Mode::MAX)
  1351. .value("AVERAGE", ROIAlign::Mode::AVERAGE)
  1352. .def(py::init([](const std::string& in) {
  1353. auto&& str = normalize_enum(in);
  1354. if (str == "MAX") return ROIAlign::Mode::MAX;
  1355. if (str == "AVERAGE") return ROIAlign::Mode::AVERAGE;
  1356. throw py::cast_error("invalid enum value " + in);
  1357. }));
  1358. py::implicitly_convertible<std::string, ROIAlign::Mode>();
  1359. ROIAlignInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1360. ROIAlignInst
  1361. .def(py::init<::megdnn::param::ROIAlign::Mode, ::megdnn::param::ROIAlign::Format, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("mode") = ::megdnn::param::ROIAlign::Mode::MAX, py::arg("format") = ::megdnn::param::ROIAlign::Format::NCHW, py::arg("spatial_scale") = 1.0, py::arg("offset") = 0.0, py::arg("pooled_height") = 1, py::arg("pooled_width") = 1, py::arg("sample_height") = 2, py::arg("sample_width") = 2, py::arg("scope") = {})
  1362. .def_readwrite("mode", &ROIAlign::mode)
  1363. .def_readwrite("format", &ROIAlign::format)
  1364. .def_readwrite("spatial_scale", &ROIAlign::spatial_scale)
  1365. .def_readwrite("offset", &ROIAlign::offset)
  1366. .def_readwrite("pooled_height", &ROIAlign::pooled_height)
  1367. .def_readwrite("pooled_width", &ROIAlign::pooled_width)
  1368. .def_readwrite("sample_height", &ROIAlign::sample_height)
  1369. .def_readwrite("sample_width", &ROIAlign::sample_width);
  1370. py::class_<ROIPooling, std::shared_ptr<ROIPooling>, OpDef> ROIPoolingInst(m, "ROIPooling");
  1371. py::enum_<ROIPooling::Mode>(ROIPoolingInst, "Mode")
  1372. .value("MAX", ROIPooling::Mode::MAX)
  1373. .value("AVERAGE", ROIPooling::Mode::AVERAGE)
  1374. .def(py::init([](const std::string& in) {
  1375. auto&& str = normalize_enum(in);
  1376. if (str == "MAX") return ROIPooling::Mode::MAX;
  1377. if (str == "AVERAGE") return ROIPooling::Mode::AVERAGE;
  1378. throw py::cast_error("invalid enum value " + in);
  1379. }));
  1380. py::implicitly_convertible<std::string, ROIPooling::Mode>();
  1381. ROIPoolingInst
  1382. .def(py::init<::megdnn::param::ROIPooling::Mode, float, std::string>(), py::arg("mode") = ::megdnn::param::ROIPooling::Mode::MAX, py::arg("scale") = 1.f, py::arg("scope") = {})
  1383. .def_readwrite("mode", &ROIPooling::mode)
  1384. .def_readwrite("scale", &ROIPooling::scale);
  1385. py::class_<Reduce, std::shared_ptr<Reduce>, OpDef> ReduceInst(m, "Reduce");
  1386. py::enum_<Reduce::Mode>(ReduceInst, "Mode")
  1387. .value("SUM", Reduce::Mode::SUM)
  1388. .value("SUM_SQR", Reduce::Mode::SUM_SQR)
  1389. .value("PRODUCT", Reduce::Mode::PRODUCT)
  1390. .value("MIN", Reduce::Mode::MIN)
  1391. .value("MAX", Reduce::Mode::MAX)
  1392. .value("MEAN", Reduce::Mode::MEAN)
  1393. .def(py::init([](const std::string& in) {
  1394. auto&& str = normalize_enum(in);
  1395. if (str == "SUM") return Reduce::Mode::SUM;
  1396. if (str == "SUM_SQR") return Reduce::Mode::SUM_SQR;
  1397. if (str == "PRODUCT") return Reduce::Mode::PRODUCT;
  1398. if (str == "MIN") return Reduce::Mode::MIN;
  1399. if (str == "MAX") return Reduce::Mode::MAX;
  1400. if (str == "MEAN") return Reduce::Mode::MEAN;
  1401. throw py::cast_error("invalid enum value " + in);
  1402. }));
  1403. py::implicitly_convertible<std::string, Reduce::Mode>();
  1404. py::enum_<Reduce::DataType>(ReduceInst, "DataType")
  1405. .value("DEFAULT", Reduce::DataType::DEFAULT)
  1406. .value("FLOAT_IO16xC32", Reduce::DataType::FLOAT_IO16xC32)
  1407. .value("FLOAT_O32xC32", Reduce::DataType::FLOAT_O32xC32)
  1408. .value("FLOAT_O16xC32", Reduce::DataType::FLOAT_O16xC32)
  1409. .value("QUINT_I8xO32", Reduce::DataType::QUINT_I8xO32)
  1410. .value("QINT_I8xO32", Reduce::DataType::QINT_I8xO32)
  1411. .def(py::init([](const std::string& in) {
  1412. auto&& str = normalize_enum(in);
  1413. if (str == "DEFAULT") return Reduce::DataType::DEFAULT;
  1414. if (str == "FLOAT_IO16xC32") return Reduce::DataType::FLOAT_IO16xC32;
  1415. if (str == "FLOAT_O32xC32") return Reduce::DataType::FLOAT_O32xC32;
  1416. if (str == "FLOAT_O16xC32") return Reduce::DataType::FLOAT_O16xC32;
  1417. if (str == "QUINT_I8xO32") return Reduce::DataType::QUINT_I8xO32;
  1418. if (str == "QINT_I8xO32") return Reduce::DataType::QINT_I8xO32;
  1419. throw py::cast_error("invalid enum value " + in);
  1420. }));
  1421. py::implicitly_convertible<std::string, Reduce::DataType>();
  1422. ReduceInst
  1423. .def(py::init<::megdnn::param::Reduce::Mode, int32_t, ::megdnn::param::Reduce::DataType, bool, std::string>(), py::arg("mode") = ::megdnn::param::Reduce::Mode::SUM, py::arg("axis") = 2147483647, py::arg("data_type") = ::megdnn::param::Reduce::DataType::DEFAULT, py::arg("keepdim") = true, py::arg("scope") = {})
  1424. .def_readwrite("mode", &Reduce::mode)
  1425. .def_readwrite("axis", &Reduce::axis)
  1426. .def_readwrite("data_type", &Reduce::data_type)
  1427. .def_readwrite("keepdim", &Reduce::keepdim);
  1428. py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap");
  1429. py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode")
  1430. .value("NEAREST", Remap::InterpolationMode::NEAREST)
  1431. .value("LINEAR", Remap::InterpolationMode::LINEAR)
  1432. .value("AREA", Remap::InterpolationMode::AREA)
  1433. .value("CUBIC", Remap::InterpolationMode::CUBIC)
  1434. .value("LANCZOS4", Remap::InterpolationMode::LANCZOS4)
  1435. .def(py::init([](const std::string& in) {
  1436. auto&& str = normalize_enum(in);
  1437. if (str == "NEAREST") return Remap::InterpolationMode::NEAREST;
  1438. if (str == "LINEAR") return Remap::InterpolationMode::LINEAR;
  1439. if (str == "AREA") return Remap::InterpolationMode::AREA;
  1440. if (str == "CUBIC") return Remap::InterpolationMode::CUBIC;
  1441. if (str == "LANCZOS4") return Remap::InterpolationMode::LANCZOS4;
  1442. throw py::cast_error("invalid enum value " + in);
  1443. }));
  1444. py::implicitly_convertible<std::string, Remap::InterpolationMode>();
  1445. py::enum_<Remap::BorderMode>(RemapInst, "BorderMode")
  1446. .value("REPLICATE", Remap::BorderMode::REPLICATE)
  1447. .value("REFLECT", Remap::BorderMode::REFLECT)
  1448. .value("REFLECT_101", Remap::BorderMode::REFLECT_101)
  1449. .value("WRAP", Remap::BorderMode::WRAP)
  1450. .value("CONSTANT", Remap::BorderMode::CONSTANT)
  1451. .value("TRANSPARENT", Remap::BorderMode::TRANSPARENT)
  1452. .value("ISOLATED", Remap::BorderMode::ISOLATED)
  1453. .def(py::init([](const std::string& in) {
  1454. auto&& str = normalize_enum(in);
  1455. if (str == "REPLICATE") return Remap::BorderMode::REPLICATE;
  1456. if (str == "REFLECT") return Remap::BorderMode::REFLECT;
  1457. if (str == "REFLECT_101") return Remap::BorderMode::REFLECT_101;
  1458. if (str == "WRAP") return Remap::BorderMode::WRAP;
  1459. if (str == "CONSTANT") return Remap::BorderMode::CONSTANT;
  1460. if (str == "TRANSPARENT") return Remap::BorderMode::TRANSPARENT;
  1461. if (str == "ISOLATED") return Remap::BorderMode::ISOLATED;
  1462. throw py::cast_error("invalid enum value " + in);
  1463. }));
  1464. py::implicitly_convertible<std::string, Remap::BorderMode>();
  1465. RemapInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1466. RemapInst
  1467. .def(py::init<::megdnn::param::Remap::InterpolationMode, ::megdnn::param::Remap::BorderMode, ::megdnn::param::Remap::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::Remap::InterpolationMode::LINEAR, py::arg("border_type") = ::megdnn::param::Remap::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::Remap::Format::NHWC, py::arg("scalar") = 0.f, py::arg("scope") = {})
  1468. .def_readwrite("imode", &Remap::imode)
  1469. .def_readwrite("border_type", &Remap::border_type)
  1470. .def_readwrite("format", &Remap::format)
  1471. .def_readwrite("scalar", &Remap::scalar);
  1472. py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef> RemoteRecvInst(m, "RemoteRecv");
  1473. RemoteRecvInst
  1474. .def(py::init<std::string, std::string, uint32_t, uint32_t, ::mgb::CompNode, std::vector<int32_t>, ::megdnn::DType, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_from"), py::arg("cn"), py::arg("shape"), py::arg("dtype"), py::arg("backend"), py::arg("scope") = {})
  1475. .def(py::init<>())
  1476. .def_readwrite("key", &RemoteRecv::key)
  1477. .def_readwrite("addr", &RemoteRecv::addr)
  1478. .def_readwrite("port", &RemoteRecv::port)
  1479. .def_readwrite("rank_from", &RemoteRecv::rank_from)
  1480. .def_readwrite("cn", &RemoteRecv::cn)
  1481. .def_readwrite("shape", &RemoteRecv::shape)
  1482. .def_readwrite("dtype", &RemoteRecv::dtype)
  1483. .def_readwrite("backend", &RemoteRecv::backend);
  1484. py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef> RemoteSendInst(m, "RemoteSend");
  1485. RemoteSendInst
  1486. .def(py::init<std::string, std::string, uint32_t, uint32_t, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_to"), py::arg("backend"), py::arg("scope") = {})
  1487. .def(py::init<>())
  1488. .def_readwrite("key", &RemoteSend::key)
  1489. .def_readwrite("addr", &RemoteSend::addr)
  1490. .def_readwrite("port", &RemoteSend::port)
  1491. .def_readwrite("rank_to", &RemoteSend::rank_to)
  1492. .def_readwrite("backend", &RemoteSend::backend);
  1493. py::class_<RemoveAxis, std::shared_ptr<RemoveAxis>, OpDef> RemoveAxisInst(m, "RemoveAxis");
  1494. RemoveAxisInst
  1495. .def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
  1496. .def(py::init<>())
  1497. .def_readwrite("axis", &RemoveAxis::axis);
  1498. py::class_<Reshape, std::shared_ptr<Reshape>, OpDef> ReshapeInst(m, "Reshape");
  1499. ReshapeInst
  1500. .def(py::init<int32_t, std::vector<int32_t>, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("shape"), py::arg("scope") = {})
  1501. .def(py::init<>())
  1502. .def_readwrite("axis", &Reshape::axis)
  1503. .def_readwrite("shape", &Reshape::shape);
  1504. py::class_<Resize, std::shared_ptr<Resize>, OpDef> ResizeInst(m, "Resize");
  1505. ResizeInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1506. ResizeInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1507. ResizeInst
  1508. .def(py::init<::megdnn::param::Resize::InterpolationMode, ::megdnn::param::Resize::Format, std::string>(), py::arg("imode") = ::megdnn::param::Resize::InterpolationMode::LINEAR, py::arg("format") = ::megdnn::param::Resize::Format::NHWC, py::arg("scope") = {})
  1509. .def_readwrite("imode", &Resize::imode)
  1510. .def_readwrite("format", &Resize::format);
  1511. py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD");
  1512. SVDInst
  1513. .def(py::init<bool, bool, std::string>(), py::arg("full_matrices") = false, py::arg("compute_uv") = true, py::arg("scope") = {})
  1514. .def_readwrite("full_matrices", &SVD::full_matrices)
  1515. .def_readwrite("compute_uv", &SVD::compute_uv);
  1516. py::class_<SetMeshIndexing, std::shared_ptr<SetMeshIndexing>, OpDef> SetMeshIndexingInst(m, "SetMeshIndexing");
  1517. SetMeshIndexingInst
  1518. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1519. .def(py::init<>())
  1520. .def_readwrite("items", &SetMeshIndexing::items);
  1521. py::class_<SetSubtensor, std::shared_ptr<SetSubtensor>, OpDef> SetSubtensorInst(m, "SetSubtensor");
  1522. SetSubtensorInst
  1523. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1524. .def(py::init<>())
  1525. .def_readwrite("items", &SetSubtensor::items);
  1526. py::class_<ShuffleRNG, std::shared_ptr<ShuffleRNG>, OpDef> ShuffleRNGInst(m, "ShuffleRNG");
  1527. ShuffleRNGInst
  1528. .def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
  1529. .def(py::init<>())
  1530. .def_readwrite("seed", &ShuffleRNG::seed)
  1531. .def_readwrite("handle", &ShuffleRNG::handle);
  1532. py::class_<SlidingWindowTranspose, std::shared_ptr<SlidingWindowTranspose>, OpDef> SlidingWindowTransposeInst(m, "SlidingWindowTranspose");
  1533. SlidingWindowTransposeInst
  1534. .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("out_h") = 0, py::arg("out_w") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
  1535. .def_readwrite("out_h", &SlidingWindowTranspose::out_h)
  1536. .def_readwrite("out_w", &SlidingWindowTranspose::out_w)
  1537. .def_readwrite("pad_h", &SlidingWindowTranspose::pad_h)
  1538. .def_readwrite("pad_w", &SlidingWindowTranspose::pad_w)
  1539. .def_readwrite("stride_h", &SlidingWindowTranspose::stride_h)
  1540. .def_readwrite("stride_w", &SlidingWindowTranspose::stride_w)
  1541. .def_readwrite("dilate_h", &SlidingWindowTranspose::dilate_h)
  1542. .def_readwrite("dilate_w", &SlidingWindowTranspose::dilate_w)
  1543. .def_readwrite("window_h", &SlidingWindowTranspose::window_h)
  1544. .def_readwrite("window_w", &SlidingWindowTranspose::window_w);
  1545. py::class_<Softmax, std::shared_ptr<Softmax>, OpDef> SoftmaxInst(m, "Softmax");
  1546. SoftmaxInst
  1547. .def(py::init<int32_t, std::string>(), py::arg("axis") = -1, py::arg("scope") = {})
  1548. .def_readwrite("axis", &Softmax::axis);
  1549. py::class_<Split, std::shared_ptr<Split>, OpDef> SplitInst(m, "Split");
  1550. SplitInst
  1551. .def(py::init<int32_t, int32_t, std::string>(), py::arg("axis"), py::arg("nsections"), py::arg("scope") = {})
  1552. .def(py::init<>())
  1553. .def_readwrite("axis", &Split::axis)
  1554. .def_readwrite("nsections", &Split::nsections);
  1555. py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
  1556. SubtensorInst
  1557. .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
  1558. .def(py::init<>())
  1559. .def_readwrite("items", &Subtensor::items);
  1560. py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
  1561. TQTInst
  1562. .def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
  1563. .def_readwrite("qmin", &TQT::qmin)
  1564. .def_readwrite("qmax", &TQT::qmax);
  1565. py::class_<TensorRTRuntime, std::shared_ptr<TensorRTRuntime>, OpDef> TensorRTRuntimeInst(m, "TensorRTRuntime");
  1566. TensorRTRuntimeInst
  1567. .def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
  1568. .def(py::init<>())
  1569. .def_readwrite("buf", &TensorRTRuntime::buf)
  1570. .def_readwrite("buf_size", &TensorRTRuntime::buf_size);
  1571. py::class_<TopK, std::shared_ptr<TopK>, OpDef> TopKInst(m, "TopK");
  1572. py::enum_<TopK::Mode>(TopKInst, "Mode")
  1573. .value("KTH_ONLY", TopK::Mode::KTH_ONLY)
  1574. .value("VALUE_IDX_NOSORT", TopK::Mode::VALUE_IDX_NOSORT)
  1575. .value("VALUE_IDX_SORTED", TopK::Mode::VALUE_IDX_SORTED)
  1576. .def(py::init([](const std::string& in) {
  1577. auto&& str = normalize_enum(in);
  1578. if (str == "KTH_ONLY") return TopK::Mode::KTH_ONLY;
  1579. if (str == "VALUE_IDX_NOSORT") return TopK::Mode::VALUE_IDX_NOSORT;
  1580. if (str == "VALUE_IDX_SORTED") return TopK::Mode::VALUE_IDX_SORTED;
  1581. throw py::cast_error("invalid enum value " + in);
  1582. }));
  1583. py::implicitly_convertible<std::string, TopK::Mode>();
  1584. TopKInst
  1585. .def(py::init<::megdnn::param::TopK::Mode, std::string>(), py::arg("mode") = ::megdnn::param::TopK::Mode::KTH_ONLY, py::arg("scope") = {})
  1586. .def_readwrite("mode", &TopK::mode);
  1587. py::class_<TypeCvt, std::shared_ptr<TypeCvt>, OpDef> TypeCvtInst(m, "TypeCvt");
  1588. TypeCvtInst
  1589. .def(py::init<::megdnn::DType, std::string>(), py::arg("dtype"), py::arg("scope") = {})
  1590. .def(py::init<>())
  1591. .def_readwrite("dtype", &TypeCvt::dtype);
  1592. py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef> UniformRNGInst(m, "UniformRNG");
  1593. UniformRNGInst
  1594. .def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
  1595. .def(py::init<>())
  1596. .def_readwrite("seed", &UniformRNG::seed)
  1597. .def_readwrite("dtype", &UniformRNG::dtype)
  1598. .def_readwrite("handle", &UniformRNG::handle);
  1599. py::class_<WarpAffine, std::shared_ptr<WarpAffine>, OpDef> WarpAffineInst(m, "WarpAffine");
  1600. WarpAffineInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1601. WarpAffineInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1602. WarpAffineInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1603. WarpAffineInst
  1604. .def(py::init<::megdnn::param::WarpAffine::InterpolationMode, ::megdnn::param::WarpAffine::BorderMode, float, ::megdnn::param::WarpAffine::Format, std::string>(), py::arg("imode") = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR, py::arg("border_mode") = ::megdnn::param::WarpAffine::BorderMode::REPLICATE, py::arg("border_val") = .0f, py::arg("format") = ::megdnn::param::WarpAffine::Format::NHWC, py::arg("scope") = {})
  1605. .def_readwrite("imode", &WarpAffine::imode)
  1606. .def_readwrite("border_mode", &WarpAffine::border_mode)
  1607. .def_readwrite("border_val", &WarpAffine::border_val)
  1608. .def_readwrite("format", &WarpAffine::format);
  1609. py::class_<WarpPerspective, std::shared_ptr<WarpPerspective>, OpDef> WarpPerspectiveInst(m, "WarpPerspective");
  1610. WarpPerspectiveInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
  1611. WarpPerspectiveInst.attr("BorderMode") = RemapInst.attr("BorderMode");
  1612. WarpPerspectiveInst.attr("Format") = AdaptivePoolingInst.attr("Format");
  1613. WarpPerspectiveInst
  1614. .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
  1615. .def_readwrite("imode", &WarpPerspective::imode)
  1616. .def_readwrite("bmode", &WarpPerspective::bmode)
  1617. .def_readwrite("format", &WarpPerspective::format)
  1618. .def_readwrite("border_val", &WarpPerspective::border_val);
  1619. // clang-format on