Browse Source

chore(imperative): remove unnecessary function template

GitOrigin-RevId: 8dd2f8c308
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
5a1f913435
2 changed files with 61 additions and 75 deletions
  1. +41
    -38
      imperative/python/src/ops.cpp
  2. +20
    -37
      imperative/tablegen/autogen.cpp

+ 41
- 38
imperative/python/src/ops.cpp View File

@@ -33,6 +33,18 @@ auto normalize_enum(const std::string& in) {
}
} // anonymous namespace

#define CATCH_ALL(RETVAL) \
catch(py::error_already_set& e) { \
e.restore(); \
return RETVAL; \
} catch(py::builtin_exception& e) { \
e.set_error(); \
return RETVAL; \
} catch(std::exception& e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
return RETVAL; \
} \

namespace {
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type
@@ -99,14 +111,6 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

template<typename T>
PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
// T: PyOpXXX inst(): return XXX in opdef.h.inl
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<std::string>::to(op.scope());
}
#define py_get_scope(class) py_get_scope_impl<PyOp(class)>

template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
@@ -116,51 +120,46 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.*attr = pyobj_convert_generic<U>::from(value);
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
} CATCH_ALL(-1)
return 0;
}
#define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

template<typename T>
int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.set_scope(pyobj_convert_generic<std::string>::from(value));
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
}
#define py_set_scope(class) py_set_scope_impl<PyOp(class)>

struct PyOpDef {
PyObject_HEAD
std::shared_ptr<OpDef> op;
static PyTypeObject py_type;
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
static PyGetSetDef py_getsetters[];
static Py_hash_t tp_hash(PyObject *obj);
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope());
}

int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
try {
reinterpret_cast<PyOp(OpDef)*>(obj)->op
->set_scope(pyobj_convert_generic<std::string>::from(value));
} CATCH_ALL(-1)
return 0;
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
{NULL}
};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
return static_cast<Py_hash_t>(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
@@ -225,6 +224,7 @@ struct pyobj_convert_generic<T,
};

void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.OpDef";
@@ -234,6 +234,7 @@ void _init_py_op_def(py::module m) {
py_type.tp_base = &PyBaseObject_Type;
py_type.tp_hash = PyOp(OpDef)::tp_hash;
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
@@ -309,6 +310,8 @@ void _init_py_op_base(py::module m) {
// auto generated opdefs
#include "opdef.cpy.inl"

#undef CATCH_ALL

} // anonymous namespace

namespace PYBIND11_NAMESPACE {


+ 20
- 37
imperative/tablegen/autogen.cpp View File

@@ -485,52 +485,44 @@ EnumWrapper<{0}::{1}>::type2str = {{
className, i.name));
}

getsetters.push_back(formatv(
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
className));

// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";

std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name);
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");

llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "\"scope\", ";
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attrs;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attrs.push_back(formatv("*{0} = NULL", attr.name));
std::vector<std::string> attr_init;
llvm::for_each(attr_name_list, [&](auto&& attr) {
attr_init.push_back(formatv("*{0} = NULL", attr));
});
initBody += llvm::join(attrs, ", ") + ";\n";
initBody += " PyObject *scope = NULL;\n";
initBody += llvm::join(attr_init, ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(", &{0}", attr.name);
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += ", &scope";
initBody += "))\n";
initBody += " return -1;\n";

llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"(
if ({1}) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1});
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
} CATCH_ALL(-1)
}
)", className, attr.name);
});
@@ -538,18 +530,9 @@ EnumWrapper<{0}::{1}>::type2str = {{
initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
pyobj_convert_generic<std::string>::from(scope));
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", className);



Loading…
Cancel
Save