fix(imperative): fix refcount management on cpython opdef
refactor(mge/imperative): fix compilation for python3.6
GitOrigin-RevId: 332a516895
release-1.2
@@ -48,7 +48,7 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): | |||||
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | ||||
): | ): | ||||
grad_fn = elemwise_add_grad_fn | grad_fn = elemwise_add_grad_fn | ||||
elif isinstance(op, Reduce) and op.mode.name == "SUM": | |||||
elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: | |||||
grad_fn = reduce_sum_grad_fn | grad_fn = reduce_sum_grad_fn | ||||
else: | else: | ||||
grad_fn = default_grad_fn | grad_fn = default_grad_fn | ||||
@@ -447,8 +447,8 @@ def _(op: OpDef, *args: VarNode): | |||||
def _(op: BackwardGraph, *args: VarNode): | def _(op: BackwardGraph, *args: VarNode): | ||||
assert args | assert args | ||||
graph = args[0].graph | graph = args[0].graph | ||||
return op.interpret( | |||||
lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||||
return BackwardGraph.interpret( | |||||
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||||
) | ) | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/utils/persistent_cache.h" | #include "megbrain/utils/persistent_cache.h" | ||||
#include "megbrain/imperative/op_def.h" | |||||
#include <Python.h> | #include <Python.h> | ||||
#include <string> | #include <string> | ||||
@@ -376,6 +377,32 @@ namespace detail { | |||||
} | } | ||||
}; | }; | ||||
template<> struct type_caster<mgb::imperative::OpDef> { | |||||
protected: | |||||
std::shared_ptr<mgb::imperative::OpDef> value; | |||||
public: | |||||
static constexpr auto name = _("OpDef"); | |||||
operator mgb::imperative::OpDef&() { return *value; } | |||||
operator const mgb::imperative::OpDef&() { return *value; } | |||||
operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; } | |||||
operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); } | |||||
template <typename T> using cast_op_type = T; | |||||
bool load(handle src, bool convert); | |||||
static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */); | |||||
static handle cast(std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy, handle parent) { | |||||
return cast(*op, policy, parent); | |||||
} | |||||
}; | |||||
template <> struct type_caster<std::shared_ptr<mgb::imperative::OpDef>> : | |||||
public type_caster<mgb::imperative::OpDef> { | |||||
template <typename T> using cast_op_type = pybind11::detail::movable_cast_op_type<T>; | |||||
}; | |||||
} // detail | } // detail | ||||
} // PYBIND11_NAMESPACE | } // PYBIND11_NAMESPACE | ||||
@@ -106,13 +106,4 @@ void init_imperative_rt(py::module m) { | |||||
}); | }); | ||||
m.def("make_backward_graph", &make_backward_graph); | m.def("make_backward_graph", &make_backward_graph); | ||||
py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") | |||||
.def("ctype", [](const OpDef& opdef) { | |||||
return opdef.dyn_typeinfo()->name; | |||||
}) | |||||
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { | |||||
return lhs.is_same(rhs); | |||||
}) | |||||
.def("__hash__", &OpDef::hash); | |||||
} | } |
@@ -63,6 +63,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||||
from .utils import * | from .utils import * | ||||
from .imperative import * | from .imperative import * | ||||
from .graph import * | from .graph import * | ||||
from .ops import OpDef | |||||
)", | )", | ||||
py::getattr(m, "__dict__")); | py::getattr(m, "__dict__")); | ||||
@@ -16,7 +16,11 @@ | |||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include <Python.h> | |||||
#include <unordered_map> | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
using namespace mgb::imperative; | |||||
namespace { | namespace { | ||||
auto normalize_enum(const std::string& in) { | auto normalize_enum(const std::string& in) { | ||||
@@ -28,20 +32,256 @@ auto normalize_enum(const std::string& in) { | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
namespace { | |||||
#define PyOp(name) Py##name | |||||
#define PyOpType(name) PyOp(name)::py_type | |||||
#define PyOpDefBegin(name) \ | |||||
struct PyOp(name) : PyOpDef { \ | |||||
using Ty = name; \ | |||||
Ty& inst() { return op->cast_final_safe<Ty>(); } \ | |||||
static PyTypeObject py_type; | |||||
#define PyOpDefEnd(name) \ | |||||
}; \ | |||||
PyTypeObject PyOpType(name); | |||||
#define RETURN_RICHCOMPARE(val1, val2, op) \ | |||||
do { \ | |||||
switch (op) { \ | |||||
case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ | |||||
default: \ | |||||
Py_FatalError("Unreachable C code path reached"); \ | |||||
} \ | |||||
} while (0) | |||||
template<typename T, typename SFINAE=void> | |||||
struct pyobj_convert_generic { | |||||
static T from(PyObject* obj) { | |||||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||||
py::detail::loader_life_support guard{}; | |||||
return py::cast<T>(py::handle(obj)); | |||||
} | |||||
template<typename U, | |||||
typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>> | |||||
static PyObject* to(U&& t) { | |||||
return py::cast(std::forward<U>(t)).release().ptr(); | |||||
} | |||||
}; | |||||
template<typename T> | |||||
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||||
PyObject* obj = type->tp_alloc(type, 0); | |||||
T* self = reinterpret_cast<T*>(obj); | |||||
if (self != NULL) { | |||||
self->op = T::Ty::make(); | |||||
} | |||||
return obj; | |||||
} | |||||
template<typename T> | |||||
void py_dealloc_generic(PyObject* obj) { | |||||
reinterpret_cast<T*>(obj)->op.reset(); | |||||
Py_TYPE(obj)->tp_free(obj); | |||||
} | |||||
template<typename T, typename U, U T::Ty::*attr> | |||||
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { | |||||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||||
return pyobj_convert_generic<U>::to(op.*attr); | |||||
} | |||||
#define py_get_generic(name, attr) \ | |||||
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||||
template<typename T, typename U, U T::Ty::*attr> | |||||
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||||
if (value == NULL) { | |||||
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); | |||||
return -1; | |||||
} | |||||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||||
try { | |||||
op.*attr = pyobj_convert_generic<U>::from(value); | |||||
return 0; | |||||
} catch(py::error_already_set& e) { | |||||
e.restore(); | |||||
} catch(py::builtin_exception& e) { | |||||
e.set_error(); | |||||
} catch(...) { | |||||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||||
} | |||||
return -1; | |||||
} | |||||
#define py_set_generic(name, attr) \ | |||||
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||||
struct PyOpDef { | |||||
PyObject_HEAD | |||||
std::shared_ptr<OpDef> op; | |||||
static PyTypeObject py_type; | |||||
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype; | |||||
static Py_hash_t tp_hash(PyObject *obj); | |||||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); | |||||
}; | |||||
PyTypeObject PyOpType(OpDef); | |||||
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; | |||||
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) { | |||||
return static_cast<Py_hash_t>( | |||||
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash()); | |||||
} | |||||
PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) { | |||||
bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same( | |||||
*reinterpret_cast<PyOp(OpDef)*>(other)->op); | |||||
if (op == Py_EQ || op == Py_NE) { | |||||
RETURN_RICHCOMPARE(same, true, op); | |||||
} | |||||
Py_RETURN_NOTIMPLEMENTED; | |||||
} | |||||
template<typename T> | |||||
struct EnumWrapper { | |||||
static_assert(std::is_enum_v<T>); | |||||
PyObject_HEAD | |||||
T value; | |||||
static const char* name; | |||||
static PyTypeObject type; | |||||
static std::unordered_map<T, std::string> type2str; | |||||
static std::unordered_map<std::string, T> str2type; | |||||
EnumWrapper() = default; | |||||
EnumWrapper(T v): value(v) {} | |||||
EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {} | |||||
std::string to_string() const { | |||||
return type2str.at(value); | |||||
} | |||||
static PyObject* py_repr(PyObject* self) { | |||||
return pyobj_convert_generic<std::string>::to( | |||||
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string()); | |||||
} | |||||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { | |||||
T lhs = reinterpret_cast<EnumWrapper*>(self)->value, | |||||
rhs = reinterpret_cast<EnumWrapper*>(other)->value; | |||||
if (op == Py_EQ || op == Py_NE) { | |||||
RETURN_RICHCOMPARE(lhs, rhs, op); | |||||
} | |||||
Py_RETURN_NOTIMPLEMENTED; | |||||
} | |||||
}; | |||||
template<typename T> | |||||
struct pyobj_convert_generic<T, | |||||
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||||
using Wrapper = EnumWrapper<T>; | |||||
static T from(PyObject* obj) { | |||||
if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||||
return reinterpret_cast<Wrapper*>(obj)->value; | |||||
} | |||||
// try as string | |||||
// TODO: type checkcd | |||||
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||||
} | |||||
static PyObject* to(T t) { | |||||
PyTypeObject* pytype = &Wrapper::type; | |||||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||||
reinterpret_cast<Wrapper*>(obj)->value = t; | |||||
return obj; | |||||
} | |||||
}; | |||||
void _init_py_op_def(py::module m) { | |||||
auto& py_type = PyOpType(OpDef); | |||||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
py_type.tp_name = "megengine.core._imperative_rt.OpDef"; | |||||
py_type.tp_basicsize = sizeof(PyOp(OpDef)); | |||||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
py_type.tp_doc = "OpDef"; | |||||
py_type.tp_base = &PyBaseObject_Type; | |||||
py_type.tp_hash = PyOp(OpDef)::tp_hash; | |||||
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); | |||||
} | |||||
/*********** begin of hand-write opdefs **************/ | |||||
PyOpDefBegin(BackwardGraph) // {{ | |||||
// }; | |||||
PyOpDefEnd(BackwardGraph) | |||||
void _init_py_backward_graph(py::module m) { | |||||
using py_op = PyOp(BackwardGraph); | |||||
auto& py_type = PyOpType(BackwardGraph); | |||||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; | |||||
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); | |||||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
py_type.tp_doc = "BackwardGraph"; | |||||
py_type.tp_base = &PyOpType(OpDef); | |||||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
py_type.tp_new = py_new_generic<py_op>; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction | |||||
auto interpret = py::cpp_function( | |||||
[](OpDef& self, py::object pyf, py::object pyc, | |||||
const mgb::SmallVector<py::object>& inputs) { | |||||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||||
}; | |||||
auto c = [pyc](const TensorPtr& tensor) { | |||||
return pyc(tensor->dev_tensor()); | |||||
}; | |||||
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs); | |||||
}); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); | |||||
PyType_Modified(&py_type); | |||||
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type)); | |||||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||||
} | |||||
/*********** end of hand-write opdefs **************/ | |||||
// auto generated opdefs | |||||
#include "opdef.cpy.inl" | |||||
} // anonymous namespace | |||||
namespace PYBIND11_NAMESPACE { | |||||
namespace detail { | |||||
bool type_caster<OpDef>::load(handle src, bool convert) { | |||||
PyObject* obj = src.ptr(); | |||||
if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) { | |||||
return false; | |||||
} | |||||
value = reinterpret_cast<PyOp(OpDef)*>(obj)->op; | |||||
return true; | |||||
} | |||||
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) { | |||||
PyTypeObject* pytype; | |||||
auto& c2p = PyOp(OpDef)::ctype2pytype; | |||||
auto&& iter = c2p.find(op.dyn_typeinfo()); | |||||
if (iter != c2p.end()) { // FIXME: should always meet this condition | |||||
pytype = iter->second; | |||||
} else { // which means unregistered op type, jsut make it as an opaque op type | |||||
// currently, only OprAttr goes into this branch | |||||
pytype = &PyOpType(OpDef); | |||||
} | |||||
PyObject* obj = pytype->tp_alloc(pytype, 0); | |||||
mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef))); | |||||
reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this(); | |||||
return py::handle(obj); | |||||
} | |||||
} // detail | |||||
} // PYBIND11_NAMESPACE | |||||
void init_ops(py::module m) { | void init_ops(py::module m) { | ||||
using namespace mgb::imperative; | |||||
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||||
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||||
const mgb::SmallVector<py::object>& inputs) { | |||||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||||
}; | |||||
auto c = [pyc](const TensorPtr& tensor) { | |||||
return pyc(tensor->dev_tensor()); | |||||
}; | |||||
return self.graph().interpret<py::object>(f, c, inputs); | |||||
}); | |||||
#include "opdef.py.inl" | |||||
_init_py_op_def(m); | |||||
_init_py_backward_graph(m); | |||||
INIT_ALL_OP(m) | |||||
} | } |
@@ -76,7 +76,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | SmallVector<TensorPtr> apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | const SmallVector<TensorPtr>& inputs) { | ||||
auto opr = def.cast_final_safe<CondTake>(); | |||||
auto&& opr = def.cast_final_safe<CondTake>(); | |||||
mgb_assert(opr.same_type<CondTake>()); | mgb_assert(opr.same_type<CondTake>()); | ||||
mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", | mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", | ||||
inputs.size()); | inputs.size()); | ||||
@@ -111,7 +111,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||||
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | const SmallVector<TensorPtr>& inputs) { | ||||
auto param = def.cast_final_safe<ParamPackSplit>(); | |||||
auto&& param = def.cast_final_safe<ParamPackSplit>(); | |||||
mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size()); | mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size()); | ||||
auto&& inp = inputs[0]; | auto&& inp = inputs[0]; | ||||
auto&& shp = inp->layout(); | auto&& shp = inp->layout(); | ||||
@@ -27,6 +27,7 @@ struct BackwardGraphResult { | |||||
}; | }; | ||||
class OpDef : public Hashable, | class OpDef : public Hashable, | ||||
public NonCopyableObj, | |||||
public std::enable_shared_from_this<OpDef> { | public std::enable_shared_from_this<OpDef> { | ||||
mutable const OpTrait* m_trait = nullptr; | mutable const OpTrait* m_trait = nullptr; | ||||
public: | public: | ||||
@@ -64,7 +65,7 @@ template<typename T> | |||||
class OpDefImplBase : public OpDef { | class OpDefImplBase : public OpDef { | ||||
public: | public: | ||||
template<typename ...Args> | template<typename ...Args> | ||||
static std::shared_ptr<OpDef> make(Args&& ...args) { | |||||
static std::shared_ptr<T> make(Args&& ...args) { | |||||
return std::make_shared<T>(std::forward<Args>(args)...); | return std::make_shared<T>(std::forward<Args>(args)...); | ||||
} | } | ||||
}; | }; | ||||
@@ -10,5 +10,6 @@ set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) | |||||
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | ||||
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | ||||
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | ||||
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) | |||||
tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension") | |||||
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen) | |||||
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) | set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) |
@@ -11,7 +11,8 @@ enum ActionType { | |||||
None, | None, | ||||
CppHeader, | CppHeader, | ||||
CppBody, | CppBody, | ||||
Pybind | |||||
Pybind, | |||||
CPython | |||||
}; | }; | ||||
// NOLINTNEXTLINE | // NOLINTNEXTLINE | ||||
@@ -22,7 +23,9 @@ llvm::cl::opt<ActionType> action( | |||||
clEnumValN(CppBody, "gen-cpp-body", | clEnumValN(CppBody, "gen-cpp-body", | ||||
"Generate operator cpp body"), | "Generate operator cpp body"), | ||||
clEnumValN(Pybind, "gen-python-binding", | clEnumValN(Pybind, "gen-python-binding", | ||||
"Generate pybind11 python bindings"))); | |||||
"Generate pybind11 python bindings"), | |||||
clEnumValN(CPython, "gen-python-c-extension", | |||||
"Generate python c extensions"))); | |||||
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | ||||
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | ||||
@@ -196,7 +199,7 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
formatMethImpl("hash") | formatMethImpl("hash") | ||||
); | ); | ||||
os << formatv( | os << formatv( | ||||
" auto op_ = def_.cast_final_safe<{0}>();\n" | |||||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
" static_cast<void>(op_);\n", | " static_cast<void>(op_);\n", | ||||
className | className | ||||
); | ); | ||||
@@ -210,8 +213,8 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
formatMethImpl("is_same_st") | formatMethImpl("is_same_st") | ||||
); | ); | ||||
os << formatv( | os << formatv( | ||||
" auto a_ = lhs_.cast_final_safe<{0}>(),\n" | |||||
" b_ = rhs_.cast_final_safe<{0}>();\n" | |||||
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||||
" &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||||
" static_cast<void>(a_);\n" | " static_cast<void>(a_);\n" | ||||
" static_cast<void>(b_);\n", | " static_cast<void>(b_);\n", | ||||
className | className | ||||
@@ -237,15 +240,15 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
} | } | ||||
} | } | ||||
struct PybindContext { | |||||
std::unordered_map<unsigned int, std::string> enumAlias; | |||||
struct EnumContext { | |||||
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||||
}; | }; | ||||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { | |||||
auto class_name = op.getCppClassName(); | |||||
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
auto className = op.getCppClassName(); | |||||
os << formatv( | os << formatv( | ||||
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | ||||
class_name | |||||
className | |||||
); | ); | ||||
for (auto&& i : op.getMgbAttributes()) { | for (auto&& i : op.getMgbAttributes()) { | ||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | ||||
@@ -263,17 +266,17 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||||
if (iter == enumAlias.end()) { | if (iter == enumAlias.end()) { | ||||
os << formatv( | os << formatv( | ||||
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | ||||
class_name, attr->getEnumName() | |||||
className, attr->getEnumName() | |||||
); | ); | ||||
std::vector<std::string> body; | std::vector<std::string> body; | ||||
for (auto&& i: attr->getEnumMembers()) { | for (auto&& i: attr->getEnumMembers()) { | ||||
os << formatv( | os << formatv( | ||||
"\n .value(\"{2}\", {0}::{1}::{2})", | "\n .value(\"{2}\", {0}::{1}::{2})", | ||||
class_name, attr->getEnumName(), i | |||||
className, attr->getEnumName(), i | |||||
); | ); | ||||
body.push_back(formatv( | body.push_back(formatv( | ||||
"if (str == \"{2}\") return {0}::{1}::{2};", | "if (str == \"{2}\") return {0}::{1}::{2};", | ||||
class_name, attr->getEnumName(), i | |||||
className, attr->getEnumName(), i | |||||
)); | )); | ||||
} | } | ||||
os << formatv( | os << formatv( | ||||
@@ -286,21 +289,21 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||||
); | ); | ||||
os << formatv( | os << formatv( | ||||
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | ||||
class_name, attr->getEnumName() | |||||
className, attr->getEnumName() | |||||
); | ); | ||||
enumAlias.emplace(enumID, formatv( | |||||
"{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() | |||||
)); | |||||
enumAlias.emplace(enumID, | |||||
std::make_pair(className, attr->getEnumName())); | |||||
} else { | } else { | ||||
os << formatv( | os << formatv( | ||||
"{0}Inst.attr(\"{1}\") = {2};\n\n", | |||||
class_name, attr->getEnumName(), iter->second | |||||
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||||
className, attr->getEnumName(), | |||||
iter->second.first, iter->second.second | |||||
); | ); | ||||
} | } | ||||
} | } | ||||
} | } | ||||
// generate op class binding | // generate op class binding | ||||
os << formatv("{0}Inst", class_name); | |||||
os << formatv("{0}Inst", className); | |||||
bool hasDefaultCtor = op.getMgbAttributes().empty(); | bool hasDefaultCtor = op.getMgbAttributes().empty(); | ||||
if (!hasDefaultCtor) { | if (!hasDefaultCtor) { | ||||
os << "\n .def(py::init<"; | os << "\n .def(py::init<"; | ||||
@@ -327,12 +330,184 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext | |||||
for (auto &&i : op.getMgbAttributes()) { | for (auto &&i : op.getMgbAttributes()) { | ||||
os << formatv( | os << formatv( | ||||
"\n .def_readwrite(\"{0}\", &{1}::{0})", | "\n .def_readwrite(\"{0}\", &{1}::{0})", | ||||
i.name, class_name | |||||
i.name, className | |||||
); | ); | ||||
} | } | ||||
os << ";\n\n"; | os << ";\n\n"; | ||||
} | } | ||||
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
auto className = op.getCppClassName(); | |||||
std::string body; | |||||
// generate PyType for enum class member | |||||
for (auto&& i : op.getMgbAttributes()) { | |||||
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
unsigned int enumID; | |||||
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
auto&& aliasBase = alias->getAliasBase(); | |||||
enumID = | |||||
llvm::cast<MgbEnumAttr>(aliasBase) | |||||
.getBaseRecord()->getID(); | |||||
} else { | |||||
enumID = attr->getBaseRecord()->getID(); | |||||
} | |||||
auto&& enumAlias = ctx.enumAlias; | |||||
auto&& iter = enumAlias.find(enumID); | |||||
auto enumName = attr->getEnumName(); | |||||
body += "{\n"; | |||||
body += formatv( | |||||
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||||
); | |||||
if (iter == enumAlias.end()) { | |||||
os << formatv( | |||||
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
className, enumName); | |||||
os << formatv( | |||||
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||||
className, enumName); | |||||
std::vector<std::string> pairStr; | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
pairStr.push_back(formatv( | |||||
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<std::string, {0}::{1}> | |||||
EnumWrapper<{0}::{1}>::str2type = {{ | |||||
{2} | |||||
}; | |||||
)", className, enumName, llvm::join(pairStr, ", ")); | |||||
pairStr.clear(); | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
pairStr.push_back(formatv( | |||||
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
className, enumName, i)); | |||||
} | |||||
os << formatv(R"( | |||||
template<> std::unordered_map<{0}::{1}, std::string> | |||||
EnumWrapper<{0}::{1}>::type2str = {{ | |||||
{2} | |||||
}; | |||||
)", className, enumName, llvm::join(pairStr, ", ")); | |||||
body += formatv(R"( | |||||
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||||
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||||
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
e_type.tp_doc = "{0}.{1}"; | |||||
e_type.tp_base = &PyBaseObject_Type; | |||||
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||||
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||||
mgb_assert(PyType_Ready(&e_type) >= 0); | |||||
)", className, enumName); | |||||
for (auto&& i: attr->getEnumMembers()) { | |||||
body += formatv(R"({{ | |||||
PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||||
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||||
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||||
})", className, enumName, i); | |||||
} | |||||
enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
} | |||||
body += formatv(R"( | |||||
PyType_Modified(&e_type); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||||
)", enumName); | |||||
body += "}\n"; | |||||
} | |||||
} | |||||
// generate getsetters | |||||
std::vector<std::string> getsetters; | |||||
for (auto &&i : op.getMgbAttributes()) { | |||||
getsetters.push_back(formatv( | |||||
"{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", | |||||
className, i.name)); | |||||
} | |||||
// generate tp_init | |||||
std::string initBody; | |||||
if (!op.getMgbAttributes().empty()) { | |||||
initBody += "static const char* kwlist[] = {"; | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
initBody += formatv("\"{0}\", ", attr.name); | |||||
}); | |||||
initBody += "NULL};\n"; | |||||
initBody += " PyObject "; | |||||
std::vector<std::string> attrs; | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
attrs.push_back(formatv("*{0} = NULL", attr.name)); | |||||
}); | |||||
initBody += llvm::join(attrs, ", ") + ";\n"; | |||||
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||||
initBody += std::string(op.getMgbAttributes().size(), 'O'); | |||||
initBody += "\", const_cast<char**>(kwlist)"; | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
initBody += formatv(" ,&{0}", attr.name); | |||||
}); | |||||
initBody += "))\n"; | |||||
initBody += " return -1;\n"; | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
initBody += formatv(R"( | |||||
if ({1}) {{ | |||||
try {{ | |||||
reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||||
pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||||
} catch(py::error_already_set& e) {{ | |||||
e.restore(); | |||||
return -1; | |||||
} catch(py::builtin_exception& e) {{ | |||||
e.set_error(); | |||||
return -1; | |||||
} catch(...) {{ | |||||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||||
return -1; | |||||
} | |||||
} | |||||
)", className, attr.name); | |||||
}); | |||||
} | |||||
initBody += "\n return 0;"; | |||||
os << formatv(R"( | |||||
PyOpDefBegin({0}) // {{ | |||||
static PyGetSetDef py_getsetters[]; | |||||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||||
// }; | |||||
PyOpDefEnd({0}) | |||||
PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||||
{1} | |||||
{{NULL} /* Sentinel */ | |||||
}; | |||||
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||||
{2} | |||||
} | |||||
void _init_py_{0}(py::module m) {{ | |||||
using py_op = PyOp({0}); | |||||
auto& py_type = PyOpType({0}); | |||||
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||||
py_type.tp_basicsize = sizeof(PyOp({0})); | |||||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
py_type.tp_doc = "{0}"; | |||||
py_type.tp_base = &PyOpType(OpDef); | |||||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
py_type.tp_new = py_new_generic<py_op>; | |||||
py_type.tp_init = py_op::py_init; | |||||
py_type.tp_getset = py_op::py_getsetters; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
{3} | |||||
PyType_Modified(&py_type); | |||||
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||||
} | |||||
)", | |||||
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||||
} | |||||
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | ||||
std::function<void(raw_ostream&, MgbOp&)> callback) { | std::function<void(raw_ostream&, MgbOp&)> callback) { | ||||
auto op_base_class = keeper.getClass("Op"); | auto op_base_class = keeper.getClass("Op"); | ||||
@@ -360,13 +535,26 @@ static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||||
} | } | ||||
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | ||||
PybindContext ctx; | |||||
EnumContext ctx; | |||||
using namespace std::placeholders; | using namespace std::placeholders; | ||||
for_each_operator(os, keeper, | for_each_operator(os, keeper, | ||||
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | ||||
return false; | return false; | ||||
} | } | ||||
static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||||
EnumContext ctx; | |||||
using namespace std::placeholders; | |||||
for_each_operator(os, keeper, | |||||
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||||
os << "#define INIT_ALL_OP(m)"; | |||||
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||||
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||||
}); | |||||
os << "\n"; | |||||
return false; | |||||
} | |||||
int main(int argc, char **argv) { | int main(int argc, char **argv) { | ||||
llvm::InitLLVM y(argc, argv); | llvm::InitLLVM y(argc, argv); | ||||
llvm::cl::ParseCommandLineOptions(argc, argv); | llvm::cl::ParseCommandLineOptions(argc, argv); | ||||
@@ -379,5 +567,8 @@ int main(int argc, char **argv) { | |||||
if (action == ActionType::Pybind) { | if (action == ActionType::Pybind) { | ||||
return TableGenMain(argv[0], &gen_op_def_pybind11); | return TableGenMain(argv[0], &gen_op_def_pybind11); | ||||
} | } | ||||
if (action == ActionType::CPython) { | |||||
return TableGenMain(argv[0], &gen_op_def_python_c_extension); | |||||
} | |||||
return -1; | return -1; | ||||
} | |||||
} |