GitOrigin-RevId: ce11fe5e09
tags/v1.8.0
@@ -606,7 +606,8 @@ class Apply(Expr):
    def apply_module_trace_hook(cls, opdef, *inputs):
        for i in inputs:
            node = NodeMixin.get(i, None)
            assert node is not None
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))

        if isinstance(opdef, FakeQuant):
            inp_nodes = [NodeMixin.get(inputs[0])]
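Note on the hunk above: inputs that reach the trace hook without a bound node (for example a Tensor created inline inside forward, as in the new test at the bottom of this diff) are now recorded as Constant exprs instead of tripping the old assertion. A minimal sketch of that fallback, assuming only the names already used in this module (NodeMixin.get, NodeMixin.wrap_safe, Constant.make):

def ensure_node(tensor):
    # Sketch only: mirrors the fallback added in apply_module_trace_hook above.
    node = NodeMixin.get(tensor, None)      # node bound to `tensor`, or None
    if node is None:                        # untracked input: capture it as a constant
        NodeMixin.wrap_safe(tensor, Constant.make(tensor))
        node = NodeMixin.get(tensor, None)
    return node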
@@ -627,7 +628,6 @@ class Apply(Expr):
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        outputs = list(map(Tensor, outputs))
        set_module_tracing()
        apply_node.add_outputs(outputs)
@@ -741,12 +741,8 @@ class Constant(Expr):
        assert isinstance(c, (RawTensor, Module))
        if isinstance(c, Module):
            assert module_tracer.is_builtin(c) or c.is_qat
        if isinstance(c, RawTensor):
            if is_tracing_module():
                unset_module_tracing()
                c = Tensor(c)
                set_module_tracing()
            else:
                c = Tensor(c)
        if type(c) is RawTensor:
            with _exclude_from_trace():
                c = Tensor(c)
        self.value = c
        self.name = name
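The Constant hunk above swaps the manual unset_module_tracing()/set_module_tracing() pair for the _exclude_from_trace() context manager. A hedged sketch of the pattern, assuming _exclude_from_trace (as used above, with the same imports available in this module) suspends module tracing for its body and restores the previous state on exit:

def as_plain_tensor(c):
    # Sketch only: convert a RawTensor without recording the conversion in the
    # traced graph. The context manager also restores the tracing flag if the
    # body raises, which the old explicit unset/set pair did not guarantee.
    with _exclude_from_trace():
        return Tensor(c)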
@@ -52,6 +52,12 @@ public:
        }
    }

    void enable() { m_enabled = 1; }
    void disable() { m_enabled = 0; }
    bool enabled() const { return m_enabled; }
    ValueRef unwrap(ValueRef value) override { return value; }
    std::string name() const override { return "ModuleTraceTransformation"; }
@@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
PyObject* TensorWrapper::module_trace_info() {
    if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
        return module_trace_info->inc_ref().ptr();
    } else {
        PyErr_SetString(
                PyExc_AttributeError,
                "Has no attribute named \'_NodeMixin__node\', please "
                "set it first");
        return nullptr;
        if (module_trace_info->ptr()) {
            return module_trace_info->inc_ref().ptr();
        }
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
}
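Python-side effect of the module_trace_info change above: an entry that exists in module_trace_info_map but holds a null object now raises AttributeError just like a missing entry. A hedged sketch of a tolerant read, assuming the getter is surfaced as the _NodeMixin__node attribute named in the error string:

def get_trace_node(tensor, default=None):
    # Sketch only: both "no entry" and "entry holds a null object" now surface
    # as AttributeError, so a single except clause covers them.
    try:
        return tensor._NodeMixin__node
    except AttributeError:
        return default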
@@ -1031,29 +1033,23 @@ void init_tensor(py::module m) {
    static py::function module_trace_hook;
    static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
    static int module_tracing = 0;
    m.def("set_module_tracing", [=] {
    static auto get_module_trace = [] {
        static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
        }
        if (++module_tracing == 1) {
            transformations.register_at<TransformationManager::ModuleTrace>(
            transformations.register_at<Segment::ModuleTrace>(
                    module_trace_transformation);
        }
    });
        return module_trace_transformation;
    };
    m.def("unset_module_tracing", [=] {
        if (--module_tracing == 0) {
            transformations.unregister<TransformationManager::ModuleTrace>(
                    module_trace_transformation);
        }
    });
    m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
    m.def("is_tracing_module", [=] { return module_tracing > 0; });
    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
    m.def("set_module_trace_hook",
          [](py::function function) { module_trace_hook = function; });
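With the init_tensor hunk above, the ModuleTraceTransformation is built and registered once by get_module_trace(), and set_module_tracing()/unset_module_tracing() only flip its enabled flag instead of reference-counting register/unregister calls. A hedged Python sketch of the resulting contract, assuming the module trace hook has already been installed (importing megengine.traced_module normally does this via set_module_trace_hook):

from megengine.core._imperative_rt.core2 import (
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)

def with_module_tracing(fn, *args):
    # Sketch only: enable()/disable() are idempotent toggles on the one shared
    # transformation, so repeated calls are safe and nothing is unregistered.
    set_module_tracing()
    try:
        return fn(*args)
    finally:
        unset_module_tracing()
        assert not is_tracing_module()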
@@ -5,9 +5,11 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.module import Module
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import CallFunction
from megengine.traced_module.expr import Apply, CallFunction, Constant


class MyModule1(M.Module):
@@ -133,3 +135,25 @@ def test_trace_module():
    tm6 = trace_module(MyModule5(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False


def test_trace_module_2():
    class Model(M.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            out = x.shape
            out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1))
            return out

    traced_model = trace_module(Model(), Tensor(([1,])))

    assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance(
        traced_model.graph._exprs[0].opdef, builtin.GetVarShape
    )
    assert isinstance(traced_model.graph._exprs[1], Constant)
    assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance(
        traced_model.graph._exprs[2].opdef, builtin.Elemwise
    )
    assert int(traced_model(Tensor([1, 2]))[0]) == 3
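A brief follow-up sketch, reusing traced_model from the test body above, to show how the captured graph can be inspected directly; traced_model.graph._exprs is the same internal expr list the asserts index into:

for expr in traced_model.graph._exprs:
    # Expected, per the asserts: Apply of GetVarShape, then Constant, then Apply of Elemwise
    print(type(expr).__name__, getattr(expr, "opdef", None))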