GitOrigin-RevId: f7f6024034
tags/v1.8.0
@@ -763,6 +763,7 @@ class Constant(Expr): | |||||
current_graph = active_module_tracer().current_scope() | current_graph = active_module_tracer().current_scope() | ||||
current_graph._namespace.auto_naming_for_outputs(expr) | current_graph._namespace.auto_naming_for_outputs(expr) | ||||
current_graph._insert(expr) | current_graph._insert(expr) | ||||
active_module_tracer().current_constant_cache().append(expr.value) | |||||
return expr.outputs[0] | return expr.outputs[0] | ||||
def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
@@ -131,6 +131,7 @@ class module_tracer: | |||||
self._active_scopes = [] | self._active_scopes = [] | ||||
self.checker = TracedModuleChecker(self) | self.checker = TracedModuleChecker(self) | ||||
self.patcher = Patcher(wrap_fn) | self.patcher = Patcher(wrap_fn) | ||||
self._activate_constant_cache = [] | |||||
@classmethod | @classmethod | ||||
def register_as_builtin(cls, mod): | def register_as_builtin(cls, mod): | ||||
@@ -145,16 +146,28 @@ class module_tracer: | |||||
def push_scope(self, scope): | def push_scope(self, scope): | ||||
self._active_scopes.append(scope) | self._active_scopes.append(scope) | ||||
self.checker.push_scope() | self.checker.push_scope() | ||||
self._activate_constant_cache.append([]) | |||||
def pop_scope(self): | def pop_scope(self): | ||||
self._active_scopes.pop() | self._active_scopes.pop() | ||||
self.checker.pop_scope() | self.checker.pop_scope() | ||||
cache = self._activate_constant_cache.pop() | |||||
for obj in cache: | |||||
if hasattr(obj, "_NodeMixin__node"): | |||||
delattr(obj, "_NodeMixin__node") | |||||
def current_scope(self): | def current_scope(self): | ||||
if self._active_scopes: | if self._active_scopes: | ||||
return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
return None | return None | ||||
def current_constant_cache(self): | |||||
if self._activate_constant_cache: | |||||
return self._activate_constant_cache[-1] | |||||
return None | |||||
def top_scope(self): | def top_scope(self): | ||||
if self._active_scopes: | if self._active_scopes: | ||||
return self._active_scopes[0] | return self._active_scopes[0] | ||||
@@ -380,6 +380,11 @@ class NodeMixin(abc.ABC): | |||||
value._record_wrapped_nodes(node) | value._record_wrapped_nodes(node) | ||||
@classmethod | @classmethod | ||||
def clear_node(cls, value): | |||||
if hasattr(value, "_NodeMixin__node"): | |||||
delattr(value, "_NodeMixin__node") | |||||
@classmethod | |||||
def get(cls, value, *default): | def get(cls, value, *default): | ||||
return getattr(value, "_NodeMixin__node", *default) | return getattr(value, "_NodeMixin__node", *default) | ||||
@@ -1980,7 +1980,10 @@ class TracedModule(Module): | |||||
assert ( | assert ( | ||||
treedef in self.argdef_graph_map | treedef in self.argdef_graph_map | ||||
), "support input args kwargs format: \n{}, but get: \n{}".format( | ), "support input args kwargs format: \n{}, but get: \n{}".format( | ||||
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()), | |||||
"\n ".join( | |||||
"forward({})".format(i._args_kwargs_repr()) | |||||
for i in self.argdef_graph_map.keys() | |||||
), | |||||
treedef._args_kwargs_repr(), | treedef._args_kwargs_repr(), | ||||
) | ) | ||||
inputs = filter( | inputs = filter( | ||||
@@ -2514,3 +2517,7 @@ def trace_module( | |||||
set_symbolic_shape(use_sym_shape) | set_symbolic_shape(use_sym_shape) | ||||
set_active_module_tracer(None) | set_active_module_tracer(None) | ||||
unset_module_tracing() | unset_module_tracing() | ||||
for t in mod.tensors(recursive=True): | |||||
NodeMixin.clear_node(t) | |||||
for t in inputs: | |||||
NodeMixin.clear_node(t) |