GitOrigin-RevId: 468a996bdd
HuaHua404-patch-1
@@ -30,12 +30,22 @@ private: | |||||
} | } | ||||
public: | public: | ||||
inline static WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||||
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | ||||
ValueRefList apply_transformation( | ValueRefList apply_transformation( | ||||
const Operator& op, Span<ValueRef> inputs) override { | const Operator& op, Span<ValueRef> inputs) override { | ||||
if (op.is<ApplyOp>() && m_enabled > 0) { | if (op.is<ApplyOp>() && m_enabled > 0) { | ||||
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | ||||
return outputs; | return outputs; | ||||
} else if (op.is<RenameValue>()) { | |||||
auto outputs = imperative::apply(op, inputs); | |||||
if (auto module_trace_info = module_trace_info_map.try_get(inputs[0])) { | |||||
if (module_trace_info->ptr()) { | |||||
auto node = module_trace_info.value(); | |||||
module_trace_info_map[outputs[0]] = module_trace_info.value(); | |||||
} | |||||
} | |||||
return outputs; | |||||
} else { | } else { | ||||
return imperative::apply(op, inputs); | return imperative::apply(op, inputs); | ||||
} | } | ||||
@@ -47,10 +47,6 @@ namespace views = ranges::views; | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
namespace { | |||||
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||||
} // namespace | |||||
interpreter::Interpreter::Channel* interpreter_for_py = nullptr; | interpreter::Interpreter::Channel* interpreter_for_py = nullptr; | ||||
PyTypeObject* py_tensor_type = nullptr; | PyTypeObject* py_tensor_type = nullptr; | ||||
PyTypeObject* py_varnode_type = nullptr; | PyTypeObject* py_varnode_type = nullptr; | ||||
@@ -594,7 +590,9 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
} | } | ||||
PyObject* TensorWrapper::module_trace_info() { | PyObject* TensorWrapper::module_trace_info() { | ||||
if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { | |||||
if (auto module_trace_info = | |||||
ModuleTraceTransformation::module_trace_info_map.try_get( | |||||
m_tensor->data())) { | |||||
if (module_trace_info->ptr()) { | if (module_trace_info->ptr()) { | ||||
return module_trace_info->inc_ref().ptr(); | return module_trace_info->inc_ref().ptr(); | ||||
} | } | ||||
@@ -608,7 +606,8 @@ PyObject* TensorWrapper::module_trace_info() { | |||||
void TensorWrapper::set_module_trace_info(PyObject* obj) { | void TensorWrapper::set_module_trace_info(PyObject* obj) { | ||||
// TODO: erase when obj == nullptr | // TODO: erase when obj == nullptr | ||||
module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); | |||||
ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] = | |||||
py::reinterpret_borrow<py::object>(obj); | |||||
} | } | ||||
void TensorWrapper::_set_format(PyObject* dest) { | void TensorWrapper::_set_format(PyObject* dest) { | ||||
@@ -620,6 +619,7 @@ void TensorWrapper::_set_format(PyObject* dest) { | |||||
void TensorWrapper::_set_name(PyObject* dest) { | void TensorWrapper::_set_name(PyObject* dest) { | ||||
auto py_dest = py::reinterpret_borrow<py::object>(dest); | auto py_dest = py::reinterpret_borrow<py::object>(dest); | ||||
auto name = py_dest.cast<std::string>(); | auto name = py_dest.cast<std::string>(); | ||||
m_tensor->set_name(name); | m_tensor->set_name(name); | ||||
} | } | ||||
@@ -9,7 +9,7 @@ from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
from megengine.module import Module | from megengine.module import Module | ||||
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | ||||
from megengine.traced_module.expr import Apply, CallFunction, Constant | |||||
from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant | |||||
class MyModule1(M.Module): | class MyModule1(M.Module): | ||||
@@ -59,6 +59,14 @@ class MyModule4(M.Module): | |||||
return self.add(x, y) | return self.add(x, y) | ||||
class MyModule5(M.Module): | |||||
def forward(self, x): | |||||
a = x + x | |||||
b = x * a | |||||
b.name = "result" | |||||
return b | |||||
def test_trace_module(): | def test_trace_module(): | ||||
enable_expr_checker() | enable_expr_checker() | ||||
x = Tensor(1) | x = Tensor(1) | ||||
@@ -157,3 +165,9 @@ def test_trace_module_2(): | |||||
traced_model.graph._exprs[2].opdef, builtin.Elemwise | traced_model.graph._exprs[2].opdef, builtin.Elemwise | ||||
) | ) | ||||
assert int(traced_model(Tensor([1, 2]))[0]) == 3 | assert int(traced_model(Tensor([1, 2]))[0]) == 3 | ||||
def test_rename(): | |||||
model = MyModule5() | |||||
tm_model = trace_module(model, Tensor(1)) | |||||
assert isinstance(tm_model.graph.outputs[0].expr, CallMethod) |