@@ -180,6 +180,25 @@ def _tensor_to_node(tensors):
     return nodes
 
 
+def _name_setter(node: Node, new_name: str):
+    surgery_mode = _set_graph_surgery_mode(False)
+    graph = active_module_tracer().current_scope()
+
+    if node.top_graph is not None:
+        top_graph = active_module_tracer().top_scope()
+        if node is top_graph._namespace.used_names.get(node._name, None):
+            graph = top_graph
+        else:
+            graph = node.top_graph
+
+    assert (
+        graph._namespace.used_names.get(new_name, None) is None
+    ), "The name(%s) is already in use. Please try a different one again." % (new_name)
+    graph._namespace.unassociate_name_with_obj(node)
+    node._name = graph._namespace.create_unique_name(new_name, node)
+    _set_graph_surgery_mode(surgery_mode)
+
+
 def _wrap_method_to_tensor_node():
     def _any_method(name, func):
         def _any(*args, **kwargs):
@@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node():
         else:
             patch.set_func(_any_method(method, patch.origin_fn))
         tensor_method_patch.append(patch)
+
+    patch = PatchedFn(Node, "name")
+    patch.set_func(property(patch.origin_fn.fget, _name_setter))
+    tensor_method_patch.append(patch)
     return tensor_method_patch
 
 
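The second hunk leaves the original Node.name getter in place and only swaps in _name_setter as the write path, by building a fresh property from patch.origin_fn.fget. Below is a minimal standalone sketch of that technique using plain attribute assignment instead of MegEngine's PatchedFn helper; the Example class and _checked_setter are made-up stand-ins for illustration.

class Example:
    def __init__(self):
        self._name = "default"

    @property
    def name(self):
        return self._name


def _checked_setter(obj, new_name):
    # Stand-in for _name_setter: validate before assigning.
    assert isinstance(new_name, str) and new_name, "name must be a non-empty string"
    obj._name = new_name


# Keep the original getter, attach the new setter -- the same
# property(patch.origin_fn.fget, _name_setter) pattern used in the diff.
Example.name = property(Example.name.fget, _checked_setter)

e = Example()
e.name = "renamed"
assert e.name == "renamed"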
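For context, the user-visible effect of _name_setter is that renaming a node re-registers it under a collision-checked, unique name in the owning graph's namespace. The sketch below shows that intent using MegEngine's public traced_module API (trace_module, graph.inputs); SimpleNet and the chosen names are made up, and whether the assignment goes through this patched setter or Node's built-in one depends on whether the module tracer is active at that point.

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.traced_module import trace_module


class SimpleNet(M.Module):  # hypothetical toy module
    def forward(self, x):
        return F.relu(x)


tm = trace_module(SimpleNet(), Tensor([1.0]))

# Pick one of the graph's input nodes and rename it through the
# Node.name property; the namespace keeps the final name unique.
inp_node = tm.graph.inputs[-1]
inp_node.name = "my_input"
print(inp_node.name)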