Browse Source

fix(traced_module): clear node after trace module

GitOrigin-RevId: f7f6024034
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
355782aecb
4 changed files with 27 additions and 1 deletions
  1. +1
    -0
      imperative/python/megengine/traced_module/expr.py
  2. +13
    -0
      imperative/python/megengine/traced_module/module_tracer.py
  3. +5
    -0
      imperative/python/megengine/traced_module/node.py
  4. +8
    -1
      imperative/python/megengine/traced_module/traced_module.py

+ 1
- 0
imperative/python/megengine/traced_module/expr.py View File

@@ -763,6 +763,7 @@ class Constant(Expr):
current_graph = active_module_tracer().current_scope()
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
active_module_tracer().current_constant_cache().append(expr.value)
return expr.outputs[0]

def interpret(self, *inputs):


+ 13
- 0
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -131,6 +131,7 @@ class module_tracer:
self._active_scopes = []
self.checker = TracedModuleChecker(self)
self.patcher = Patcher(wrap_fn)
self._activate_constant_cache = []

@classmethod
def register_as_builtin(cls, mod):
@@ -145,16 +146,28 @@ class module_tracer:
def push_scope(self, scope):
self._active_scopes.append(scope)
self.checker.push_scope()
self._activate_constant_cache.append([])


def pop_scope(self):
self._active_scopes.pop()
self.checker.pop_scope()
cache = self._activate_constant_cache.pop()
for obj in cache:
if hasattr(obj, "_NodeMixin__node"):
delattr(obj, "_NodeMixin__node")


def current_scope(self):
if self._active_scopes:
return self._active_scopes[-1]
return None

def current_constant_cache(self):
if self._activate_constant_cache:
return self._activate_constant_cache[-1]
return None

def top_scope(self):
if self._active_scopes:
return self._active_scopes[0]


+ 5
- 0
imperative/python/megengine/traced_module/node.py View File

@@ -380,6 +380,11 @@ class NodeMixin(abc.ABC):
value._record_wrapped_nodes(node)

@classmethod
def clear_node(cls, value):
if hasattr(value, "_NodeMixin__node"):
delattr(value, "_NodeMixin__node")

@classmethod
def get(cls, value, *default):
return getattr(value, "_NodeMixin__node", *default)



+ 8
- 1
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1980,7 +1980,10 @@ class TracedModule(Module):
assert (
treedef in self.argdef_graph_map
), "support input args kwargs format: \n{}, but get: \n{}".format(
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()),
"\n ".join(
"forward({})".format(i._args_kwargs_repr())
for i in self.argdef_graph_map.keys()
),
treedef._args_kwargs_repr(),
)
inputs = filter(
@@ -2514,3 +2517,7 @@ def trace_module(
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()
for t in mod.tensors(recursive=True):
NodeMixin.clear_node(t)
for t in inputs:
NodeMixin.clear_node(t)

Loading…
Cancel
Save