Browse Source

refactor(imperative): alloc enum type class on heap

GitOrigin-RevId: d2b2acea22
release-1.4
Megvii Engine Team 4 years ago
parent
commit
331995e7f6
4 changed files with 62 additions and 32 deletions
  1. +4
    -6
      imperative/python/src/ops.cpp
  2. +13
    -0
      imperative/python/test/unit/core/test_imperative_rt.py
  3. +44
    -26
      imperative/tablegen/targets/python_c_extension.cpp
  4. +1
    -0
      imperative/test/CMakeLists.txt

+ 4
- 6
imperative/python/src/ops.cpp View File

@@ -170,7 +170,7 @@ struct EnumTrait;
PyObject_HEAD \
T value; \
constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \
static PyTypeObject* type; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
@@ -196,7 +196,7 @@ struct EnumWrapper {
}
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<EnumWrapper*>(obj)->value;
return true;
}
@@ -224,7 +224,6 @@ struct EnumWrapper {
template<typename T>
struct BitCombinedEnumWrapper {
PyEnumHead
static PyNumberMethods number_methods;
std::string to_string() const {
uint32_t value_int = static_cast<uint32_t>(value);
if (value_int == 0) {
@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper {
}
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
return true;
}
@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper {
auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max);
if ((!v) || (v & (v - 1))) {
PyTypeObject* pytype = &type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} else {


+ 13
- 0
imperative/python/test/unit/core/test_imperative_rt.py View File

@@ -69,3 +69,16 @@ def test_raw_tensor():
np.testing.assert_allclose(x * x, yy.numpy())
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
np.testing.assert_allclose(x * x, yy.numpy())


def test_opdef_path():
from megengine.core.ops.builtin import Elemwise

assert Elemwise.__module__ == "megengine.core._imperative_rt.ops"
assert Elemwise.__name__ == "Elemwise"
assert Elemwise.__qualname__ == "Elemwise"

Mode = Elemwise.Mode
assert Mode.__module__ == "megengine.core._imperative_rt.ops"
assert Mode.__name__ == "Mode"
assert Mode.__qualname__ == "Elemwise.Mode"

+ 44
- 26
imperative/tablegen/targets/python_c_extension.cpp View File

@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;

os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n",
"template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
&ctx);

auto quote = [&](auto&& i) -> std::string {
@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0};
"template<> PyObject* "
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
&ctx, attr->getEnumMembers().size());

if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods = {};\n",
&ctx);
}
}

Initproc EnumAttrEmitter::emit_initproc() {
@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) {

if (firstOccur) {
os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
)", &ctx);
if (attr->getEnumCombinedFlag()) {
// only bit combined enum could new instance because bitwise operation,
// others should always use singleton
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
{Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
{Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
{Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
)", &ctx);
}
os << R"(
{0, NULL}
};)";

os << tgfmt(R"(
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.$opClass.$enumClass",
// basicsize
sizeof($enumTpl<$opClass::$enumClass>),
// itemsize
0,
// flags
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
// slots
slots
};)", &ctx);

os << tgfmt(R"(
e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
)", &ctx);

os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";
for (auto&& i : {
std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
{"__module__", "megengine.core._imperative_rt.ops"},
{"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
os << formatv(R"(
mgb_assert(
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("{0}").release().ptr(),
py::cast("{1}").release().ptr()) >= 0);
)", i.first, i.second);
}


auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx);
}
os << " PyType_Modified(&e_type);\n";
}

os << tgfmt(R"(
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0);
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
)", &ctx);
os << "}\n";
return initproc;


+ 1
- 0
imperative/test/CMakeLists.txt View File

@@ -11,6 +11,7 @@ endif()

# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})

# Python binding


Loading…
Cancel
Save