GitOrigin-RevId: 5b45bded1d
release-1.6
@@ -0,0 +1,7 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
@@ -0,0 +1,34 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from typing import Dict | |||||
from ...core._imperative_rt import OpDef | |||||
from ...core.ops import builtin | |||||
from ...version import __version__ | |||||
OPDEF_PARAM_LOADER = {} | |||||
def get_opdef_state(obj: OpDef) -> Dict: | |||||
state = obj.__getstate__() | |||||
state["type"] = type(obj) | |||||
state["version"] = __version__ | |||||
return state | |||||
def load_opdef_from_state(state: Dict) -> OpDef: | |||||
assert "type" in state and issubclass(state["type"], OpDef) | |||||
assert "version" in state | |||||
opdef_type = state.pop("type") | |||||
if opdef_type in OPDEF_PARAM_LOADER: | |||||
loader = OPDEF_PARAM_LOADER[opdef_type] | |||||
state = loader(state) | |||||
state.pop("version") | |||||
opdef_obj = opdef_type() | |||||
opdef_obj.__setstate__(state) | |||||
return opdef_obj |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file imperative/python/src/module_trace.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./module_trace.h" | |||||
#include "./helper.h" // include op pybind11 caster | |||||
namespace py = pybind11; | |||||
namespace mgb::imperative::python { | |||||
apply_result_t apply_module_trace(ApplyContext& ctx) { | |||||
apply_result_t outputs; | |||||
auto args = py::tuple(ctx.nargs + 1); | |||||
args[0] = py::cast(ctx.op); | |||||
for (size_t i = 0; i < ctx.nargs; i++) { | |||||
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); | |||||
} | |||||
auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr); | |||||
if (!pyout) throw py::error_already_set(); | |||||
auto ret = py::reinterpret_steal<py::object>(pyout); | |||||
// assumption: python function always returns PyList | |||||
auto tup = py::reinterpret_borrow<py::list>(ret); | |||||
for (auto i = 0; i < tup.size(); i++) { | |||||
auto tw = TensorWrapper::try_cast(tup[i].ptr()); | |||||
outputs.emplace_back(tw->m_tensor); | |||||
} | |||||
return outputs; | |||||
} | |||||
} // namespace mgb::imperative::python |
@@ -0,0 +1,20 @@ | |||||
/** | |||||
* \file imperative/python/src/module_trace.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "./tensor.h" | |||||
namespace mgb::imperative::python { | |||||
apply_result_t apply_module_trace(ApplyContext& ctx); | |||||
} // namespace mgb::imperative::python |
@@ -88,6 +88,19 @@ PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||||
return obj; | return obj; | ||||
} | } | ||||
template<typename T, typename SNIFAE=void> | |||||
struct serialization { | |||||
static T load(py::object obj) { | |||||
return py::cast<T>(obj); | |||||
} | |||||
template<typename U, | |||||
typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>> | |||||
static py::object dump(U&& t) { | |||||
return py::cast(std::forward<U>(t)); | |||||
} | |||||
}; | |||||
template<typename T> | template<typename T> | ||||
void py_dealloc_generic(PyObject* obj) { | void py_dealloc_generic(PyObject* obj) { | ||||
reinterpret_cast<T*>(obj)->op.reset(); | reinterpret_cast<T*>(obj)->op.reset(); | ||||
@@ -127,6 +140,13 @@ struct PyOpDef { | |||||
static PyGetSetDef py_getsetters[]; | static PyGetSetDef py_getsetters[]; | ||||
static Py_hash_t tp_hash(PyObject *obj); | static Py_hash_t tp_hash(PyObject *obj); | ||||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); | static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); | ||||
static PyObject* py_repr(PyObject* self) { | |||||
return py::cast( | |||||
reinterpret_cast<PyOpDef*>(self)->op->make_name()) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
}; | }; | ||||
PyTypeObject PyOpType(OpDef); | PyTypeObject PyOpType(OpDef); | ||||
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; | std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; | ||||
@@ -191,6 +211,13 @@ struct EnumWrapper { | |||||
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string()) | std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string()) | ||||
.release().ptr(); | .release().ptr(); | ||||
} | } | ||||
static PyObject* py_dump(PyObject* self) { | |||||
return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string()) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { | static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { | ||||
if (op == Py_EQ || op == Py_NE) { | if (op == Py_EQ || op == Py_NE) { | ||||
T lhs, rhs; | T lhs, rhs; | ||||
@@ -279,6 +306,19 @@ struct BitCombinedEnumWrapper { | |||||
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()) | reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()) | ||||
.release().ptr(); | .release().ptr(); | ||||
} | } | ||||
static PyObject* py_dump(PyObject* self) { | |||||
std::vector<std::string> result; | |||||
auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value; | |||||
uint32_t value_int = static_cast<uint32_t>(value); | |||||
for (uint32_t i = 0; i < 32; i++) { | |||||
if (value_int >> i & 1) { | |||||
result.push_back(members[i]); | |||||
} | |||||
} | |||||
return py::tuple(py::cast(result)).release().ptr(); | |||||
} | |||||
static PyObject* py_or(PyObject* self, PyObject* other) { | static PyObject* py_or(PyObject* self, PyObject* other) { | ||||
if(!(self->ob_type == other->ob_type)){ | if(!(self->ob_type == other->ob_type)){ | ||||
return PyErr_Format( | return PyErr_Format( | ||||
@@ -326,6 +366,24 @@ struct BitCombinedEnumWrapper { | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
if (py::isinstance<py::tuple>(src)) { | |||||
auto params = py::cast<std::vector<std::string>>(src); | |||||
bool first = true; | |||||
for (auto s : params){ | |||||
auto&& iter = mem2value.find(normalize_enum(s)); | |||||
if (iter != mem2value.end()) { | |||||
if (first) { | |||||
value = iter->second; | |||||
first = false; | |||||
} else { | |||||
value |= iter->second; | |||||
} | |||||
} else { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
if (py::isinstance<py::int_>(obj)) { | if (py::isinstance<py::int_>(obj)) { | ||||
auto v = py::cast<std::underlying_type_t<T>>(src); | auto v = py::cast<std::underlying_type_t<T>>(src); | ||||
if(v > EnumTrait<T>::max) { | if(v > EnumTrait<T>::max) { | ||||
@@ -351,6 +409,25 @@ struct BitCombinedEnumWrapper { | |||||
} | } | ||||
}; | }; | ||||
template<typename T> | |||||
struct serialization<T, | |||||
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||||
static T load(py::object obj) { | |||||
auto caster = pybind11::detail::type_caster<T>(); | |||||
if (caster.load(obj, true)) { | |||||
return caster; | |||||
} else { | |||||
PyErr_SetString(PyExc_RuntimeError, | |||||
"load faild \n"); | |||||
return caster; | |||||
} | |||||
} | |||||
static py::object dump(T t) { | |||||
return py::cast(t).attr("dump")(); | |||||
} | |||||
}; | |||||
void _init_py_op_def(py::module m) { | void _init_py_op_def(py::module m) { | ||||
using py_op = PyOp(OpDef); | using py_op = PyOp(OpDef); | ||||
auto& py_type = PyOpType(OpDef); | auto& py_type = PyOpType(OpDef); | ||||
@@ -363,6 +440,7 @@ void _init_py_op_def(py::module m) { | |||||
py_type.tp_hash = PyOp(OpDef)::tp_hash; | py_type.tp_hash = PyOp(OpDef)::tp_hash; | ||||
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; | py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; | ||||
py_type.tp_getset = py_op::py_getsetters; | py_type.tp_getset = py_op::py_getsetters; | ||||
py_type.tp_repr = py_op::py_repr; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | mgb_assert(PyType_Ready(&py_type) >= 0); | ||||
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); | m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); | ||||
} | } | ||||
@@ -451,18 +451,11 @@ public: | |||||
template<typename... Args> | template<typename... Args> | ||||
static PyObject* cnew(Args&&... args) { | static PyObject* cnew(Args&&... args) { | ||||
auto* pytype = type().operator->(); | auto* pytype = type().operator->(); | ||||
auto* self = pytype->tp_alloc(pytype, 0); | |||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||||
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||||
} | |||||
new(inst) T(std::forward<Args>(args)...); | |||||
return self; | |||||
return cnew_with_type(pytype, std::forward<Args>(args)...); | |||||
} | } | ||||
template<typename... Args> | template<typename... Args> | ||||
static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) { | static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) { | ||||
auto* self = pytype->tp_alloc(pytype, 0); | auto* self = pytype->tp_alloc(pytype, 0); | ||||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | ||||
if constexpr (has_vectorcall && tp_vectorcall::valid) { | if constexpr (has_vectorcall && tp_vectorcall::valid) { | ||||
@@ -20,6 +20,7 @@ | |||||
#include "./tensor.h" | #include "./tensor.h" | ||||
#include "./grad.h" | #include "./grad.h" | ||||
#include "./trace.h" | #include "./trace.h" | ||||
#include "./module_trace.h" | |||||
#include "./common.h" | #include "./common.h" | ||||
#include "./numpy_dtypes.h" | #include "./numpy_dtypes.h" | ||||
#include "./graph_rt.h" | #include "./graph_rt.h" | ||||
@@ -41,6 +42,7 @@ interpreter::Interpreter::Channel* interpreter_for_py; | |||||
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing; | PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing; | ||||
PyObject *cpp_apply_backward_varnode; | PyObject *cpp_apply_backward_varnode; | ||||
PyObject *cpp_apply_module_trace; | |||||
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) { | std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) { | ||||
if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { | if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { | ||||
@@ -70,6 +72,7 @@ std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) { | |||||
REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | ||||
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) | REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) | ||||
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) | REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) | ||||
REGISTE_APPLY_FUNC(cpp_apply_module_trace) | |||||
#undef REGISTE_APPLY_FUNC | #undef REGISTE_APPLY_FUNC | ||||
@@ -79,6 +82,14 @@ Tensor::flags_t ApplyContext::global_enable = 0; | |||||
void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; } | void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; } | ||||
void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; } | void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; } | ||||
void set_module_tracing() { ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE; } | |||||
void unset_module_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE; } | |||||
bool is_tracing_module() { | |||||
return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE; | |||||
} | |||||
bool skip_tracing = false; | bool skip_tracing = false; | ||||
apply_result_t apply(ApplyContext& ctx) { | apply_result_t apply(ApplyContext& ctx) { | ||||
@@ -117,6 +128,11 @@ apply_result_t apply(ApplyContext& ctx) { | |||||
return ret; | return ret; | ||||
} | } | ||||
if (flags & Tensor::Flags::MODULE_TRACE) { | |||||
return apply_module_trace(ctx); | |||||
} | |||||
if (flags & Tensor::Flags::TRACE) { | if (flags & Tensor::Flags::TRACE) { | ||||
return apply_trace(ctx); | return apply_trace(ctx); | ||||
} else { | } else { | ||||
@@ -310,6 +326,21 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) | |||||
#undef REGISTE_TENSORWRAPPER_FUNC | #undef REGISTE_TENSORWRAPPER_FUNC | ||||
PyObject* TensorWrapper::module_trace_info() { | |||||
if (!m_tensor->m_module_trace_info.ptr()) { | |||||
PyErr_SetString(PyExc_AttributeError, | |||||
"Has no attribute named \'_NodeMixin__node\', please " | |||||
"set it first"); | |||||
return nullptr; | |||||
} | |||||
return m_tensor->m_module_trace_info.inc_ref().ptr(); | |||||
} | |||||
void TensorWrapper::set_module_trace_info(PyObject* obj) { | |||||
m_tensor->m_module_trace_info = py::reinterpret_borrow<py::object>(obj); | |||||
} | |||||
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ | ||||
PyObject* TensorWrapper::member() { \ | PyObject* TensorWrapper::member() { \ | ||||
@@ -495,7 +526,9 @@ void TensorWrapper::reset(PyObject* tensor) { | |||||
} | } | ||||
std::string user_custom_name = m_tensor->user_custom_name; | std::string user_custom_name = m_tensor->user_custom_name; | ||||
std::string automatic_name = m_tensor->automatic_name; | std::string automatic_name = m_tensor->automatic_name; | ||||
auto module_trace_info = m_tensor->m_module_trace_info; | |||||
m_tensor = t->m_tensor; | m_tensor = t->m_tensor; | ||||
m_tensor->m_module_trace_info = module_trace_info; | |||||
m_tensor->user_custom_name = user_custom_name; | m_tensor->user_custom_name = user_custom_name; | ||||
m_tensor->automatic_name = automatic_name; | m_tensor->automatic_name = automatic_name; | ||||
} | } | ||||
@@ -856,6 +889,7 @@ void init_tensor(py::module m) { | |||||
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") | .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") | ||||
.def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name") | .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name") | ||||
.def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name") | .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name") | ||||
.def_getset<&TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node") | |||||
.finalize(); | .finalize(); | ||||
if (!tensor_type) throw py::error_already_set(); | if (!tensor_type) throw py::error_already_set(); | ||||
py::setattr(m, "Tensor", tensor_type); | py::setattr(m, "Tensor", tensor_type); | ||||
@@ -998,7 +1032,7 @@ void init_tensor(py::module m) { | |||||
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); | m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); | ||||
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); | m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); | ||||
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); | m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); | ||||
m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace); | |||||
m.attr("skip_tracing") = &skip_tracing; | m.attr("skip_tracing") = &skip_tracing; | ||||
py::class_<SharedHandle>(m, "SharedHandle") | py::class_<SharedHandle>(m, "SharedHandle") | ||||
@@ -1016,6 +1050,9 @@ void init_tensor(py::module m) { | |||||
m.def("set_allow_higher_order_directive", [](bool value){ | m.def("set_allow_higher_order_directive", [](bool value){ | ||||
GradKey::allow_higher_order_directive = value; | GradKey::allow_higher_order_directive = value; | ||||
}); | }); | ||||
m.def("set_module_tracing", &set_module_tracing); | |||||
m.def("unset_module_tracing", &unset_module_tracing); | |||||
m.def("is_tracing_module", &is_tracing_module); | |||||
} | } | ||||
#undef MGE_PY_INTERFACE | #undef MGE_PY_INTERFACE | ||||
@@ -96,6 +96,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
static constexpr flags_t SCALAR = 1; | static constexpr flags_t SCALAR = 1; | ||||
static constexpr flags_t GRAD = 1 << 1; | static constexpr flags_t GRAD = 1 << 1; | ||||
static constexpr flags_t TRACE = 1 << 2; | static constexpr flags_t TRACE = 1 << 2; | ||||
static constexpr flags_t MODULE_TRACE = 1 << 3; | |||||
}; | }; | ||||
flags_t m_flags = 0; | flags_t m_flags = 0; | ||||
@@ -106,6 +107,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||||
std::string user_custom_name; | std::string user_custom_name; | ||||
std::string automatic_name; | std::string automatic_name; | ||||
cg::VarNode* m_var; | cg::VarNode* m_var; | ||||
pybind11::object m_module_trace_info; | |||||
using Handle = interpreter::Interpreter::Handle; | using Handle = interpreter::Interpreter::Handle; | ||||
@@ -158,10 +160,10 @@ struct TensorWrapper { | |||||
using wrap_t = pyext17::wrap<TensorWrapper>; | using wrap_t = pyext17::wrap<TensorWrapper>; | ||||
friend wrap_t; | friend wrap_t; | ||||
inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();} | |||||
inline static TensorWrapper* try_cast(PyObject* op) { | |||||
if (!wrap_t::type().isinstance(op)) return nullptr; | |||||
return cast(op); | |||||
inline static TensorWrapper* cast(PyObject* obj) {return reinterpret_cast<wrap_t*>(obj)->inst();} | |||||
inline static TensorWrapper* try_cast(PyObject* obj) { | |||||
if (!wrap_t::type().isinstance(obj)) return nullptr; | |||||
return cast(obj); | |||||
} | } | ||||
inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);} | inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);} | ||||
@@ -206,6 +208,8 @@ struct TensorWrapper { | |||||
void set_compiled_info(PyObject *); | void set_compiled_info(PyObject *); | ||||
PyObject* trace_mixin_info(); | PyObject* trace_mixin_info(); | ||||
void set_trace_mixin_info(PyObject *); | void set_trace_mixin_info(PyObject *); | ||||
PyObject* module_trace_info(); | |||||
void set_module_trace_info(PyObject *); | |||||
PyObject* user_custom_name(); | PyObject* user_custom_name(); | ||||
void set_user_custom_name(PyObject *); | void set_user_custom_name(PyObject *); | ||||
PyObject* automatic_name(); | PyObject* automatic_name(); | ||||
@@ -331,6 +335,7 @@ void init_tensor(pybind11::module); | |||||
extern PyObject *cpp_apply_with_tracing; | extern PyObject *cpp_apply_with_tracing; | ||||
extern PyObject *cpp_apply_backward_varnode; | extern PyObject *cpp_apply_backward_varnode; | ||||
extern PyObject *cpp_apply_module_trace; | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
@@ -14,6 +14,11 @@ import numpy as np | |||||
import megengine as mge | import megengine as mge | ||||
from megengine import Parameter, Tensor | from megengine import Parameter, Tensor | ||||
from megengine.core.ops import builtin | |||||
from megengine.experimental.traced_module.serialization import ( | |||||
get_opdef_state, | |||||
load_opdef_from_state, | |||||
) | |||||
def test_tensor_serialization(): | def test_tensor_serialization(): | ||||
@@ -86,3 +91,25 @@ def test_compatibility(): | |||||
test_old_tensor("tensor_v1_1.mge") | test_old_tensor("tensor_v1_1.mge") | ||||
test_old_tensor("tensor_v1_2.mge") | test_old_tensor("tensor_v1_2.mge") | ||||
def test_opdef_serialization(): | |||||
with TemporaryFile() as f: | |||||
x = builtin.Elemwise(mode="Add") | |||||
pickle.dump(get_opdef_state(x), f) | |||||
f.seek(0) | |||||
load_x = load_opdef_from_state(pickle.load(f)) | |||||
assert x == load_x | |||||
with TemporaryFile() as f: | |||||
x = builtin.Convolution(stride_h=9, compute_mode="float32") | |||||
x.strategy = ( | |||||
builtin.Convolution.Strategy.PROFILE | |||||
| builtin.Convolution.Strategy.HEURISTIC | |||||
| builtin.Convolution.Strategy.REPRODUCIBLE | |||||
) | |||||
pickle.dump(get_opdef_state(x), f) | |||||
f.seek(0) | |||||
load_x = load_opdef_from_state(pickle.load(f)) | |||||
assert x.strategy == load_x.strategy | |||||
assert x == load_x |
@@ -34,6 +34,7 @@ private: | |||||
void emit_class(); | void emit_class(); | ||||
void emit_py_init(); | void emit_py_init(); | ||||
void emit_py_getsetters(); | void emit_py_getsetters(); | ||||
void emit_py_methods(); | |||||
Initproc emit_initproc(); | Initproc emit_initproc(); | ||||
MgbOp& op; | MgbOp& op; | ||||
@@ -133,9 +134,16 @@ void $0(PyTypeObject& py_type) { | |||||
if (firstOccur) { | if (firstOccur) { | ||||
os << tgfmt(R"( | os << tgfmt(R"( | ||||
static PyMethodDef tp_methods[] = { | |||||
{const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL}, | |||||
{NULL} /* Sentinel */ | |||||
}; | |||||
)", &ctx); | |||||
os << tgfmt(R"( | |||||
static PyType_Slot slots[] = { | static PyType_Slot slots[] = { | ||||
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr}, | {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr}, | ||||
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare}, | {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare}, | ||||
{Py_tp_methods, tp_methods}, | |||||
)", &ctx); | )", &ctx); | ||||
if (attr->getEnumCombinedFlag()) { | if (attr->getEnumCombinedFlag()) { | ||||
// only bit combined enum could new instance because bitwise operation, | // only bit combined enum could new instance because bitwise operation, | ||||
@@ -212,17 +220,62 @@ Initproc OpDefEmitter::emit() { | |||||
emit_class(); | emit_class(); | ||||
emit_py_init(); | emit_py_init(); | ||||
emit_py_getsetters(); | emit_py_getsetters(); | ||||
emit_py_methods(); | |||||
return emit_initproc(); | return emit_initproc(); | ||||
} | } | ||||
void OpDefEmitter::emit_class() { | void OpDefEmitter::emit_class() { | ||||
auto&& className = op.getCppClassName(); | |||||
std::string method_defs; | |||||
std::vector<std::string> body; | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
body.push_back(formatv(R"( | |||||
{{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})" | |||||
, attr.name)); | |||||
}); | |||||
method_defs += formatv(R"( | |||||
static PyObject* getstate(PyObject* self, PyObject*) {{ | |||||
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst(); | |||||
static_cast<void>(opdef); | |||||
std::unordered_map<std::string, py::object> state {{ | |||||
{1} | |||||
}; | |||||
return py::cast(state).release().ptr(); | |||||
})", className, llvm::join(body, ",")); | |||||
body.clear(); | |||||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||||
body.push_back(formatv(R"( | |||||
{{ | |||||
auto&& iter = state.find("{0}"); | |||||
if (iter != state.end()) { | |||||
opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second); | |||||
} | |||||
})", attr.name)); | |||||
}); | |||||
method_defs += formatv(R"( | |||||
static PyObject* setstate(PyObject* self, PyObject* args) {{ | |||||
PyObject* dict = PyTuple_GetItem(args, 0); | |||||
if (!dict) return NULL; | |||||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||||
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst(); | |||||
static_cast<void>(opdef); | |||||
{1} | |||||
Py_RETURN_NONE; | |||||
})", className, llvm::join(body, "\n")); | |||||
os << tgfmt(R"( | os << tgfmt(R"( | ||||
PyOpDefBegin($_self) // { | PyOpDefBegin($_self) // { | ||||
static PyGetSetDef py_getsetters[]; | static PyGetSetDef py_getsetters[]; | ||||
static PyMethodDef tp_methods[]; | |||||
$0 | |||||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | ||||
// }; | // }; | ||||
PyOpDefEnd($_self) | PyOpDefEnd($_self) | ||||
)", &ctx); | |||||
)", &ctx, method_defs); | |||||
} | } | ||||
void OpDefEmitter::emit_py_init() { | void OpDefEmitter::emit_py_init() { | ||||
@@ -302,6 +355,33 @@ PyGetSetDef PyOp($_self)::py_getsetters[] = { | |||||
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); | )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); | ||||
} | } | ||||
void OpDefEmitter::emit_py_methods(){ | |||||
// generate methods | |||||
std::string method_defs; | |||||
std::vector<std::string> method_items; | |||||
{ | |||||
auto&& className = op.getCppClassName(); | |||||
// generate getstate | |||||
method_items.push_back(formatv( | |||||
"{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, METH_NOARGS, \"{0} getstate\"},", | |||||
className)); | |||||
// generate setstate | |||||
method_items.push_back(formatv( | |||||
"{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, METH_VARARGS, \"{0} setstate\"},", | |||||
className)); | |||||
} | |||||
os << tgfmt(R"( | |||||
PyMethodDef PyOp($_self)::tp_methods[] = { | |||||
$0 | |||||
{NULL} /* Sentinel */ | |||||
}; | |||||
)", &ctx, llvm::join(method_items, "\n ")); | |||||
} | |||||
Initproc OpDefEmitter::emit_initproc() { | Initproc OpDefEmitter::emit_initproc() { | ||||
std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); | std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); | ||||
std::string subclass_init_call; | std::string subclass_init_call; | ||||
@@ -321,6 +401,7 @@ void $0(py::module m) { | |||||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | py_type.tp_dealloc = py_dealloc_generic<py_op>; | ||||
py_type.tp_new = py_new_generic<py_op>; | py_type.tp_new = py_new_generic<py_op>; | ||||
py_type.tp_init = py_op::py_init; | py_type.tp_init = py_op::py_init; | ||||
py_type.tp_methods = py_op::tp_methods; | |||||
py_type.tp_getset = py_op::py_getsetters; | py_type.tp_getset = py_op::py_getsetters; | ||||
mgb_assert(PyType_Ready(&py_type) >= 0); | mgb_assert(PyType_Ready(&py_type) >= 0); | ||||
$1 | $1 | ||||