GitOrigin-RevId: 9ecf6f2c5b
tags/v1.8.0
@@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [ | |||||
"dtype", | "dtype", | ||||
"grad", | "grad", | ||||
"item", | "item", | ||||
"name", | |||||
"ndim", | "ndim", | ||||
"numpy", | "numpy", | ||||
"qparams", | "qparams", | ||||
@@ -152,6 +151,11 @@ class module_tracer: | |||||
return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
return None | return None | ||||
def top_scope(self): | |||||
if self._active_scopes: | |||||
return self._active_scopes[0] | |||||
return None | |||||
class NotExist: | class NotExist: | ||||
pass | pass | ||||
@@ -180,6 +180,25 @@ def _tensor_to_node(tensors): | |||||
return nodes | return nodes | ||||
def _name_setter(node: Node, new_name: str): | |||||
surgery_mode = _set_graph_surgery_mode(False) | |||||
graph = active_module_tracer().current_scope() | |||||
if node.top_graph is not None: | |||||
top_graph = active_module_tracer().top_scope() | |||||
if node is top_graph._namespace.used_names.get(node._name, None): | |||||
graph = top_graph | |||||
else: | |||||
graph = node.top_graph | |||||
assert ( | |||||
graph._namespace.used_names.get(new_name, None) is None | |||||
), "The name(%s) is already in use. Please try a different one again." % (new_name) | |||||
graph._namespace.unassociate_name_with_obj(node) | |||||
node._name = graph._namespace.create_unique_name(new_name, node) | |||||
_set_graph_surgery_mode(surgery_mode) | |||||
def _wrap_method_to_tensor_node(): | def _wrap_method_to_tensor_node(): | ||||
def _any_method(name, func): | def _any_method(name, func): | ||||
def _any(*args, **kwargs): | def _any(*args, **kwargs): | ||||
@@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node(): | |||||
else: | else: | ||||
patch.set_func(_any_method(method, patch.origin_fn)) | patch.set_func(_any_method(method, patch.origin_fn)) | ||||
tensor_method_patch.append(patch) | tensor_method_patch.append(patch) | ||||
patch = PatchedFn(Node, "name") | |||||
patch.set_func(property(patch.origin_fn.fget, _name_setter)) | |||||
tensor_method_patch.append(patch) | |||||
return tensor_method_patch | return tensor_method_patch | ||||
@@ -377,6 +377,33 @@ def test_set_node_name(): | |||||
rename("output") | rename("output") | ||||
np.testing.assert_equal(str(graph.outputs[0]), "output") | np.testing.assert_equal(str(graph.outputs[0]), "output") | ||||
def add_1(x): | |||||
x = x + 1 | |||||
x.name = "func_add_1" | |||||
return x | |||||
class ModuleAdd_3(M.Module): | |||||
def forward(self, x): | |||||
x = x + 1 | |||||
x.name = "module_add_1" | |||||
x = x + 2 | |||||
return x | |||||
setattr(traced_module, "add_3", ModuleAdd_3()) | |||||
self = graph.inputs[0] | |||||
with graph.insert_exprs(): | |||||
x = output_node + 1 | |||||
x.name = "_add_1" | |||||
x = add_1(x) | |||||
x = self.add_3(x) | |||||
graph.replace_node({output_node: x}) | |||||
graph.compile() | |||||
assert "_add_1" in graph._namespace.used_names | |||||
assert "func_add_1" in graph._namespace.used_names | |||||
assert "module_add_1" in traced_module.add_3.graph._namespace.used_names | |||||
def test_set_graph_name(): | def test_set_graph_name(): | ||||
traced_module, x, expect = _init_module() | traced_module, x, expect = _init_module() | ||||