|
|
@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() { |
|
|
|
if (!firstOccur) return; |
|
|
|
|
|
|
|
os << tgfmt( |
|
|
|
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n", |
|
|
|
"template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n", |
|
|
|
&ctx); |
|
|
|
|
|
|
|
auto quote = [&](auto&& i) -> std::string { |
|
|
@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0}; |
|
|
|
"template<> PyObject* " |
|
|
|
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n", |
|
|
|
&ctx, attr->getEnumMembers().size()); |
|
|
|
|
|
|
|
if (attr->getEnumCombinedFlag()) { |
|
|
|
os << tgfmt( |
|
|
|
"template<> PyNumberMethods " |
|
|
|
"$enumTpl<$opClass::$enumClass>::number_methods = {};\n", |
|
|
|
&ctx); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Initproc EnumAttrEmitter::emit_initproc() { |
|
|
@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) { |
|
|
|
|
|
|
|
if (firstOccur) { |
|
|
|
os << tgfmt(R"( |
|
|
|
e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; |
|
|
|
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; |
|
|
|
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); |
|
|
|
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; |
|
|
|
e_type.tp_doc = "$opClass.$enumClass"; |
|
|
|
e_type.tp_base = &PyBaseObject_Type; |
|
|
|
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; |
|
|
|
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; |
|
|
|
static PyType_Slot slots[] = { |
|
|
|
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr}, |
|
|
|
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare}, |
|
|
|
)", &ctx); |
|
|
|
if (attr->getEnumCombinedFlag()) { |
|
|
|
// only bit combined enum could new instance because bitwise operation, |
|
|
|
// others should always use singleton |
|
|
|
os << tgfmt(R"( |
|
|
|
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; |
|
|
|
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; |
|
|
|
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; |
|
|
|
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; |
|
|
|
e_type.tp_as_number = &number_method; |
|
|
|
{Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum}, |
|
|
|
{Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or}, |
|
|
|
{Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and}, |
|
|
|
)", &ctx); |
|
|
|
} |
|
|
|
os << R"( |
|
|
|
{0, NULL} |
|
|
|
};)"; |
|
|
|
|
|
|
|
os << tgfmt(R"( |
|
|
|
static PyType_Spec spec = { |
|
|
|
// name |
|
|
|
"megengine.core._imperative_rt.ops.$opClass.$enumClass", |
|
|
|
// basicsize |
|
|
|
sizeof($enumTpl<$opClass::$enumClass>), |
|
|
|
// itemsize |
|
|
|
0, |
|
|
|
// flags |
|
|
|
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, |
|
|
|
// slots |
|
|
|
slots |
|
|
|
};)", &ctx); |
|
|
|
|
|
|
|
os << tgfmt(R"( |
|
|
|
e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec)); |
|
|
|
)", &ctx); |
|
|
|
|
|
|
|
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; |
|
|
|
for (auto&& i : { |
|
|
|
std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)}, |
|
|
|
{"__module__", "megengine.core._imperative_rt.ops"}, |
|
|
|
{"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) { |
|
|
|
os << formatv(R"( |
|
|
|
mgb_assert( |
|
|
|
e_type->tp_setattro( |
|
|
|
reinterpret_cast<PyObject*>(e_type), |
|
|
|
py::cast("{0}").release().ptr(), |
|
|
|
py::cast("{1}").release().ptr()) >= 0); |
|
|
|
)", i.first, i.second); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto&& members = attr->getEnumMembers(); |
|
|
|
for (size_t idx = 0; idx < members.size(); ++ idx) { |
|
|
|
os << tgfmt(R"({ |
|
|
|
PyObject* inst = e_type.tp_alloc(&e_type, 0); |
|
|
|
PyObject* inst = e_type->tp_alloc(e_type, 0); |
|
|
|
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; |
|
|
|
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); |
|
|
|
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); |
|
|
|
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; |
|
|
|
})", &ctx, members[idx], idx); |
|
|
|
} |
|
|
|
os << " PyType_Modified(&e_type);\n"; |
|
|
|
} |
|
|
|
|
|
|
|
os << tgfmt(R"( |
|
|
|
Py_INCREF(e_type); |
|
|
|
mgb_assert(PyDict_SetItemString( |
|
|
|
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0); |
|
|
|
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0); |
|
|
|
)", &ctx); |
|
|
|
os << "}\n"; |
|
|
|
return initproc; |
|
|
|