diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index cc486b0a..19624482 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -606,7 +606,8 @@ class Apply(Expr): def apply_module_trace_hook(cls, opdef, *inputs): for i in inputs: node = NodeMixin.get(i, None) - assert node is not None + if node is None: # capture as constant + NodeMixin.wrap_safe(i, Constant.make(i)) if isinstance(opdef, FakeQuant): inp_nodes = [NodeMixin.get(inputs[0])] @@ -627,7 +628,6 @@ class Apply(Expr): unset_module_tracing() outputs = apply(opdef, *inputs) - outputs = list(map(Tensor, outputs)) set_module_tracing() apply_node.add_outputs(outputs) @@ -741,12 +741,8 @@ class Constant(Expr): assert isinstance(c, (RawTensor, Module)) if isinstance(c, Module): assert module_tracer.is_builtin(c) or c.is_qat - if isinstance(c, RawTensor): - if is_tracing_module(): - unset_module_tracing() - c = Tensor(c) - set_module_tracing() - else: + if type(c) is RawTensor: + with _exclude_from_trace(): c = Tensor(c) self.value = c self.name = name diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index 410900ec..eee35d6c 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -52,6 +52,12 @@ public: } } + void enable() { m_enabled = 1; } + + void disable() { m_enabled = 0; } + + bool enabled() const { return m_enabled; } + ValueRef unwrap(ValueRef value) override { return value; } std::string name() const override { return "ModuleTraceTransformation"; } diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 4d7b6767..8cede64a 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { PyObject* TensorWrapper::module_trace_info() { if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { - return module_trace_info->inc_ref().ptr(); - } else { - PyErr_SetString( - PyExc_AttributeError, - "Has no attribute named \'_NodeMixin__node\', please " - "set it first"); - return nullptr; + if (module_trace_info->ptr()) { + return module_trace_info->inc_ref().ptr(); + } } + PyErr_SetString( + PyExc_AttributeError, + "Has no attribute named \'_NodeMixin__node\', please " + "set it first"); + return nullptr; } void TensorWrapper::set_module_trace_info(PyObject* obj) { + // TODO: erase when obj == nullptr module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow(obj); } @@ -1031,29 +1033,23 @@ void init_tensor(py::module m) { static py::function module_trace_hook; - static std::shared_ptr module_trace_transformation; - static int module_tracing = 0; - - m.def("set_module_tracing", [=] { + static auto get_module_trace = [] { + static std::shared_ptr module_trace_transformation; if (!module_trace_transformation) { mgb_assert(module_trace_hook); module_trace_transformation = std::make_shared(module_trace_hook); - } - if (++module_tracing == 1) { - transformations.register_at( + transformations.register_at( module_trace_transformation); } - }); + return module_trace_transformation; + }; - m.def("unset_module_tracing", [=] { - if (--module_tracing == 0) { - transformations.unregister( - module_trace_transformation); - } - }); + m.def("set_module_tracing", [=] { get_module_trace()->enable(); }); + + m.def("unset_module_tracing", [=] { get_module_trace()->disable(); }); - m.def("is_tracing_module", [=] { return module_tracing > 0; }); + m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); }); m.def("set_module_trace_hook", [](py::function function) { module_trace_hook = function; }); diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index d3baf153..43d3f492 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -5,9 +5,11 @@ import numpy as np import megengine.functional as F import megengine.module as M from megengine import Tensor -from megengine.module.module import Module +from megengine.core._imperative_rt.core2 import apply +from megengine.core.ops import builtin +from megengine.module import Module from megengine.traced_module import TracedModule, enable_expr_checker, trace_module -from megengine.traced_module.expr import CallFunction +from megengine.traced_module.expr import Apply, CallFunction, Constant class MyModule1(M.Module): @@ -133,3 +135,25 @@ def test_trace_module(): tm6 = trace_module(MyModule5(), a, b) assert tm6.m1.argspec is None assert tm6.m1._is_top is False + + +def test_trace_module_2(): + class Model(M.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = x.shape + out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1)) + return out + + traced_model = trace_module(Model(), Tensor(([1,]))) + + assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance( + traced_model.graph._exprs[0].opdef, builtin.GetVarShape + ) + assert isinstance(traced_model.graph._exprs[1], Constant) + assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance( + traced_model.graph._exprs[2].opdef, builtin.Elemwise + ) + assert int(traced_model(Tensor([1, 2]))[0]) == 3