fix(imperative): fix refcount management on cpython opdef
refactor(mge/imperative): fix compilation for python3.6
GitOrigin-RevId: 332a516895
release-1.2
@@ -48,7 +48,7 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
): | |||
grad_fn = elemwise_add_grad_fn | |||
elif isinstance(op, Reduce) and op.mode.name == "SUM": | |||
elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: | |||
grad_fn = reduce_sum_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
@@ -447,8 +447,8 @@ def _(op: OpDef, *args: VarNode): | |||
def _(op: BackwardGraph, *args: VarNode): | |||
assert args | |||
graph = args[0].graph | |||
return op.interpret( | |||
lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||
return BackwardGraph.interpret( | |||
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||
) | |||
@@ -13,6 +13,7 @@ | |||
#include "megbrain/graph.h" | |||
#include "megbrain/utils/persistent_cache.h" | |||
#include "megbrain/imperative/op_def.h" | |||
#include <Python.h> | |||
#include <string> | |||
@@ -376,6 +377,32 @@ namespace detail { | |||
} | |||
}; | |||
template<> struct type_caster<mgb::imperative::OpDef> { | |||
protected: | |||
std::shared_ptr<mgb::imperative::OpDef> value; | |||
public: | |||
static constexpr auto name = _("OpDef"); | |||
operator mgb::imperative::OpDef&() { return *value; } | |||
operator const mgb::imperative::OpDef&() { return *value; } | |||
operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; } | |||
operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); } | |||
template <typename T> using cast_op_type = T; | |||
bool load(handle src, bool convert); | |||
static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */); | |||
static handle cast(std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy, handle parent) { | |||
return cast(*op, policy, parent); | |||
} | |||
}; | |||
template <> struct type_caster<std::shared_ptr<mgb::imperative::OpDef>> : | |||
public type_caster<mgb::imperative::OpDef> { | |||
template <typename T> using cast_op_type = pybind11::detail::movable_cast_op_type<T>; | |||
}; | |||
} // detail | |||
} // PYBIND11_NAMESPACE | |||
@@ -106,13 +106,4 @@ void init_imperative_rt(py::module m) { | |||
}); | |||
m.def("make_backward_graph", &make_backward_graph); | |||
py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") | |||
.def("ctype", [](const OpDef& opdef) { | |||
return opdef.dyn_typeinfo()->name; | |||
}) | |||
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { | |||
return lhs.is_same(rhs); | |||
}) | |||
.def("__hash__", &OpDef::hash); | |||
} |
@@ -63,6 +63,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||
from .utils import * | |||
from .imperative import * | |||
from .graph import * | |||
from .ops import OpDef | |||
)", | |||
py::getattr(m, "__dict__")); | |||
@@ -16,7 +16,11 @@ | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include <Python.h> | |||
#include <unordered_map> | |||
namespace py = pybind11; | |||
using namespace mgb::imperative; | |||
namespace { | |||
auto normalize_enum(const std::string& in) { | |||
@@ -28,20 +32,256 @@ auto normalize_enum(const std::string& in) { | |||
} | |||
} // anonymous namespace | |||
namespace { | |||
#define PyOp(name) Py##name | |||
#define PyOpType(name) PyOp(name)::py_type | |||
#define PyOpDefBegin(name) \ | |||
struct PyOp(name) : PyOpDef { \ | |||
using Ty = name; \ | |||
Ty& inst() { return op->cast_final_safe<Ty>(); } \ | |||
static PyTypeObject py_type; | |||
#define PyOpDefEnd(name) \ | |||
}; \ | |||
PyTypeObject PyOpType(name); | |||
#define RETURN_RICHCOMPARE(val1, val2, op) \ | |||
do { \ | |||
switch (op) { \ | |||
case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||
default: \ | |||
Py_FatalError("Unreachable C code path reached"); \ | |||
} \ | |||
} while (0) | |||
template<typename T, typename SFINAE=void> | |||
struct pyobj_convert_generic { | |||
static T from(PyObject* obj) { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
return py::cast<T>(py::handle(obj)); | |||
} | |||
template<typename U, | |||
typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>> | |||
static PyObject* to(U&& t) { | |||
return py::cast(std::forward<U>(t)).release().ptr(); | |||
} | |||
}; | |||
template<typename T> | |||
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||
PyObject* obj = type->tp_alloc(type, 0); | |||
T* self = reinterpret_cast<T*>(obj); | |||
if (self != NULL) { | |||
self->op = T::Ty::make(); | |||
} | |||
return obj; | |||
} | |||
template<typename T> | |||
void py_dealloc_generic(PyObject* obj) { | |||
reinterpret_cast<T*>(obj)->op.reset(); | |||
Py_TYPE(obj)->tp_free(obj); | |||
} | |||
template<typename T, typename U, U T::Ty::*attr> | |||
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { | |||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
return pyobj_convert_generic<U>::to(op.*attr); | |||
} | |||
#define py_get_generic(name, attr) \ | |||
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
template<typename T, typename U, U T::Ty::*attr> | |||
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
if (value == NULL) { | |||
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); | |||
return -1; | |||
} | |||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
try { | |||
op.*attr = pyobj_convert_generic<U>::from(value); | |||
return 0; | |||
} catch(py::error_already_set& e) { | |||
e.restore(); | |||
} catch(py::builtin_exception& e) { | |||
e.set_error(); | |||
} catch(...) { | |||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
} | |||
return -1; | |||
} | |||
#define py_set_generic(name, attr) \ | |||
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
struct PyOpDef { | |||
PyObject_HEAD | |||
std::shared_ptr<OpDef> op; | |||
static PyTypeObject py_type; | |||
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype; | |||
static Py_hash_t tp_hash(PyObject *obj); | |||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); | |||
}; | |||
PyTypeObject PyOpType(OpDef); | |||
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; | |||
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) { | |||
return static_cast<Py_hash_t>( | |||
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash()); | |||
} | |||
PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) { | |||
bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same( | |||
*reinterpret_cast<PyOp(OpDef)*>(other)->op); | |||
if (op == Py_EQ || op == Py_NE) { | |||
RETURN_RICHCOMPARE(same, true, op); | |||
} | |||
Py_RETURN_NOTIMPLEMENTED; | |||
} | |||
template<typename T> | |||
struct EnumWrapper { | |||
static_assert(std::is_enum_v<T>); | |||
PyObject_HEAD | |||
T value; | |||
static const char* name; | |||
static PyTypeObject type; | |||
static std::unordered_map<T, std::string> type2str; | |||
static std::unordered_map<std::string, T> str2type; | |||
EnumWrapper() = default; | |||
EnumWrapper(T v): value(v) {} | |||
EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {} | |||
std::string to_string() const { | |||
return type2str.at(value); | |||
} | |||
static PyObject* py_repr(PyObject* self) { | |||
return pyobj_convert_generic<std::string>::to( | |||
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string()); | |||
} | |||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { | |||
T lhs = reinterpret_cast<EnumWrapper*>(self)->value, | |||
rhs = reinterpret_cast<EnumWrapper*>(other)->value; | |||
if (op == Py_EQ || op == Py_NE) { | |||
RETURN_RICHCOMPARE(lhs, rhs, op); | |||
} | |||
Py_RETURN_NOTIMPLEMENTED; | |||
} | |||
}; | |||
template<typename T> | |||
struct pyobj_convert_generic<T, | |||
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||
using Wrapper = EnumWrapper<T>; | |||
static T from(PyObject* obj) { | |||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
return reinterpret_cast<Wrapper*>(obj)->value; | |||
} | |||
// try as string | |||
// TODO: type checkcd | |||
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||
} | |||
static PyObject* to(T t) { | |||
PyTypeObject* pytype = &Wrapper::type; | |||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
reinterpret_cast<Wrapper*>(obj)->value = t; | |||
return obj; | |||
} | |||
}; | |||
void _init_py_op_def(py::module m) { | |||
auto& py_type = PyOpType(OpDef); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.OpDef"; | |||
py_type.tp_basicsize = sizeof(PyOp(OpDef)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "OpDef"; | |||
py_type.tp_base = &PyBaseObject_Type; | |||
py_type.tp_hash = PyOp(OpDef)::tp_hash; | |||
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); | |||
} | |||
/*********** begin of hand-write opdefs **************/ | |||
PyOpDefBegin(BackwardGraph) // {{ | |||
// }; | |||
PyOpDefEnd(BackwardGraph) | |||
void _init_py_backward_graph(py::module m) { | |||
using py_op = PyOp(BackwardGraph); | |||
auto& py_type = PyOpType(BackwardGraph); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; | |||
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "BackwardGraph"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction | |||
auto interpret = py::cpp_function( | |||
[](OpDef& self, py::object pyf, py::object pyc, | |||
const mgb::SmallVector<py::object>& inputs) { | |||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
}; | |||
auto c = [pyc](const TensorPtr& tensor) { | |||
return pyc(tensor->dev_tensor()); | |||
}; | |||
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); | |||
PyType_Modified(&py_type); | |||
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||
} | |||
/*********** end of hand-write opdefs **************/ | |||
// auto generated opdefs | |||
#include "opdef.cpy.inl" | |||
} // anonymous namespace | |||
namespace PYBIND11_NAMESPACE { | |||
namespace detail { | |||
bool type_caster<OpDef>::load(handle src, bool convert) { | |||
PyObject* obj = src.ptr(); | |||
if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) { | |||
return false; | |||
} | |||
value = reinterpret_cast<PyOp(OpDef)*>(obj)->op; | |||
return true; | |||
} | |||
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) { | |||
PyTypeObject* pytype; | |||
auto& c2p = PyOp(OpDef)::ctype2pytype; | |||
auto&& iter = c2p.find(op.dyn_typeinfo()); | |||
if (iter != c2p.end()) { // FIXME: should always meet this condition | |||
pytype = iter->second; | |||
} else { // which means unregistered op type, jsut make it as an opaque op type | |||
// currently, only OprAttr goes into this branch | |||
pytype = &PyOpType(OpDef); | |||
} | |||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef))); | |||
reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this(); | |||
return py::handle(obj); | |||
} | |||
} // detail | |||
} // PYBIND11_NAMESPACE | |||
void init_ops(py::module m) { | |||
using namespace mgb::imperative; | |||
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
const mgb::SmallVector<py::object>& inputs) { | |||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
}; | |||
auto c = [pyc](const TensorPtr& tensor) { | |||
return pyc(tensor->dev_tensor()); | |||
}; | |||
return self.graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
#include "opdef.py.inl" | |||
_init_py_op_def(m); | |||
_init_py_backward_graph(m); | |||
INIT_ALL_OP(m) | |||
} |
@@ -76,7 +76,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
auto opr = def.cast_final_safe<CondTake>(); | |||
auto&& opr = def.cast_final_safe<CondTake>(); | |||
mgb_assert(opr.same_type<CondTake>()); | |||
mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", | |||
inputs.size()); | |||
@@ -111,7 +111,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
auto param = def.cast_final_safe<ParamPackSplit>(); | |||
auto&& param = def.cast_final_safe<ParamPackSplit>(); | |||
mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size()); | |||
auto&& inp = inputs[0]; | |||
auto&& shp = inp->layout(); | |||
@@ -27,6 +27,7 @@ struct BackwardGraphResult { | |||
}; | |||
class OpDef : public Hashable, | |||
public NonCopyableObj, | |||
public std::enable_shared_from_this<OpDef> { | |||
mutable const OpTrait* m_trait = nullptr; | |||
public: | |||
@@ -64,7 +65,7 @@ template<typename T> | |||
class OpDefImplBase : public OpDef { | |||
public: | |||
template<typename ...Args> | |||
static std::shared_ptr<OpDef> make(Args&& ...args) { | |||
static std::shared_ptr<T> make(Args&& ...args) { | |||
return std::make_shared<T>(std::forward<Args>(args)...); | |||
} | |||
}; | |||
@@ -10,5 +10,6 @@ set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) | |||
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | |||
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | |||
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | |||
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) | |||
tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension") | |||
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen) | |||
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) |
@@ -11,7 +11,8 @@ enum ActionType { | |||
None, | |||
CppHeader, | |||
CppBody, | |||
Pybind | |||
Pybind, | |||
CPython | |||
}; | |||
// NOLINTNEXTLINE | |||
@@ -22,7 +23,9 @@ llvm::cl::opt<ActionType> action( | |||
clEnumValN(CppBody, "gen-cpp-body", | |||
"Generate operator cpp body"), | |||
clEnumValN(Pybind, "gen-python-binding", | |||
"Generate pybind11 python bindings"))); | |||
"Generate pybind11 python bindings"), | |||
clEnumValN(CPython, "gen-python-c-extension", | |||
"Generate python c extensions"))); | |||
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
@@ -196,7 +199,7 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
formatMethImpl("hash") | |||
); | |||
os << formatv( | |||
" auto op_ = def_.cast_final_safe<{0}>();\n" | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
className | |||
); | |||
@@ -210,8 +213,8 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
formatMethImpl("is_same_st") | |||
); | |||
os << formatv( | |||
" auto a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
" b_ = rhs_.cast_final_safe<{0}>();\n" | |||
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
" &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(a_);\n" | |||
" static_cast<void>(b_);\n", | |||
className | |||
@@ -237,15 +240,15 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
} | |||
} | |||
struct PybindContext { | |||
std::unordered_map<unsigned int, std::string> enumAlias; | |||
struct EnumContext { | |||
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
}; | |||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { | |||
auto class_name = op.getCppClassName(); | |||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
os << formatv( | |||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
class_name | |||
className | |||
); | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
@@ -263,17 +266,17 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
class_name, attr->getEnumName() | |||
className, attr->getEnumName() | |||
); | |||
std::vector<std::string> body; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
os << formatv( | |||
"\n .value(\"{2}\", {0}::{1}::{2})", | |||
class_name, attr->getEnumName(), i | |||
className, attr->getEnumName(), i | |||
); | |||
body.push_back(formatv( | |||
"if (str == \"{2}\") return {0}::{1}::{2};", | |||
class_name, attr->getEnumName(), i | |||
className, attr->getEnumName(), i | |||
)); | |||
} | |||
os << formatv( | |||
@@ -286,21 +289,21 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
); | |||
os << formatv( | |||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
class_name, attr->getEnumName() | |||
className, attr->getEnumName() | |||
); | |||
enumAlias.emplace(enumID, formatv( | |||
"{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() | |||
)); | |||
enumAlias.emplace(enumID, | |||
std::make_pair(className, attr->getEnumName())); | |||
} else { | |||
os << formatv( | |||
"{0}Inst.attr(\"{1}\") = {2};\n\n", | |||
class_name, attr->getEnumName(), iter->second | |||
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
className, attr->getEnumName(), | |||
iter->second.first, iter->second.second | |||
); | |||
} | |||
} | |||
} | |||
// generate op class binding | |||
os << formatv("{0}Inst", class_name); | |||
os << formatv("{0}Inst", className); | |||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
if (!hasDefaultCtor) { | |||
os << "\n .def(py::init<"; | |||
@@ -327,12 +330,184 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv( | |||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
i.name, class_name | |||
i.name, className | |||
); | |||
} | |||
os << ";\n\n"; | |||
} | |||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
auto className = op.getCppClassName(); | |||
std::string body; | |||
// generate PyType for enum class member | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
unsigned int enumID; | |||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
auto&& aliasBase = alias->getAliasBase(); | |||
enumID = | |||
llvm::cast<MgbEnumAttr>(aliasBase) | |||
.getBaseRecord()->getID(); | |||
} else { | |||
enumID = attr->getBaseRecord()->getID(); | |||
} | |||
auto&& enumAlias = ctx.enumAlias; | |||
auto&& iter = enumAlias.find(enumID); | |||
auto enumName = attr->getEnumName(); | |||
body += "{\n"; | |||
body += formatv( | |||
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||
); | |||
if (iter == enumAlias.end()) { | |||
os << formatv( | |||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
className, enumName); | |||
os << formatv( | |||
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||
className, enumName); | |||
std::vector<std::string> pairStr; | |||
for (auto&& i: attr->getEnumMembers()) { | |||
pairStr.push_back(formatv( | |||
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<std::string, {0}::{1}> | |||
EnumWrapper<{0}::{1}>::str2type = {{ | |||
{2} | |||
}; | |||
)", className, enumName, llvm::join(pairStr, ", ")); | |||
pairStr.clear(); | |||
for (auto&& i: attr->getEnumMembers()) { | |||
pairStr.push_back(formatv( | |||
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
className, enumName, i)); | |||
} | |||
os << formatv(R"( | |||
template<> std::unordered_map<{0}::{1}, std::string> | |||
EnumWrapper<{0}::{1}>::type2str = {{ | |||
{2} | |||
}; | |||
)", className, enumName, llvm::join(pairStr, ", ")); | |||
body += formatv(R"( | |||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
e_type.tp_doc = "{0}.{1}"; | |||
e_type.tp_base = &PyBaseObject_Type; | |||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||
)", className, enumName); | |||
for (auto&& i: attr->getEnumMembers()) { | |||
body += formatv(R"({{ | |||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
})", className, enumName, i); | |||
} | |||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
} | |||
body += formatv(R"( | |||
PyType_Modified(&e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
)", enumName); | |||
body += "}\n"; | |||
} | |||
} | |||
// generate getsetters | |||
std::vector<std::string> getsetters; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
getsetters.push_back(formatv( | |||
"{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", | |||
className, i.name)); | |||
} | |||
// generate tp_init | |||
std::string initBody; | |||
if (!op.getMgbAttributes().empty()) { | |||
initBody += "static const char* kwlist[] = {"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv("\"{0}\", ", attr.name); | |||
}); | |||
initBody += "NULL};\n"; | |||
initBody += " PyObject "; | |||
std::vector<std::string> attrs; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
attrs.push_back(formatv("*{0} = NULL", attr.name)); | |||
}); | |||
initBody += llvm::join(attrs, ", ") + ";\n"; | |||
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
initBody += std::string(op.getMgbAttributes().size(), 'O'); | |||
initBody += "\", const_cast<char**>(kwlist)"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv(" ,&{0}", attr.name); | |||
}); | |||
initBody += "))\n"; | |||
initBody += " return -1;\n"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv(R"( | |||
if ({1}) {{ | |||
try {{ | |||
reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||
pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||
} catch(py::error_already_set& e) {{ | |||
e.restore(); | |||
return -1; | |||
} catch(py::builtin_exception& e) {{ | |||
e.set_error(); | |||
return -1; | |||
} catch(...) {{ | |||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
return -1; | |||
} | |||
} | |||
)", className, attr.name); | |||
}); | |||
} | |||
initBody += "\n return 0;"; | |||
os << formatv(R"( | |||
PyOpDefBegin({0}) // {{ | |||
static PyGetSetDef py_getsetters[]; | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd({0}) | |||
PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||
{1} | |||
{{NULL} /* Sentinel */ | |||
}; | |||
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||
{2} | |||
} | |||
void _init_py_{0}(py::module m) {{ | |||
using py_op = PyOp({0}); | |||
auto& py_type = PyOpType({0}); | |||
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||
py_type.tp_basicsize = sizeof(PyOp({0})); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "{0}"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
{3} | |||
PyType_Modified(&py_type); | |||
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||
} | |||
)", | |||
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||
} | |||
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
auto op_base_class = keeper.getClass("Op"); | |||
@@ -360,13 +535,26 @@ static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
} | |||
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
PybindContext ctx; | |||
EnumContext ctx; | |||
using namespace std::placeholders; | |||
for_each_operator(os, keeper, | |||
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
return false; | |||
} | |||
static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||
EnumContext ctx; | |||
using namespace std::placeholders; | |||
for_each_operator(os, keeper, | |||
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||
os << "#define INIT_ALL_OP(m)"; | |||
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||
}); | |||
os << "\n"; | |||
return false; | |||
} | |||
int main(int argc, char **argv) { | |||
llvm::InitLLVM y(argc, argv); | |||
llvm::cl::ParseCommandLineOptions(argc, argv); | |||
@@ -379,5 +567,8 @@ int main(int argc, char **argv) { | |||
if (action == ActionType::Pybind) { | |||
return TableGenMain(argv[0], &gen_op_def_pybind11); | |||
} | |||
if (action == ActionType::CPython) { | |||
return TableGenMain(argv[0], &gen_op_def_python_c_extension); | |||
} | |||
return -1; | |||
} | |||
} |