Browse Source

feat(mge/traced_module): support to modify the name of Node during graph surgery

GitOrigin-RevId: 9ecf6f2c5b
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
3ff5ca5ffe
3 changed files with 55 additions and 1 deletions
  1. +5
    -1
      imperative/python/megengine/traced_module/module_tracer.py
  2. +23
    -0
      imperative/python/megengine/traced_module/traced_module.py
  3. +27
    -0
      imperative/python/test/unit/traced_module/test_modification.py

+ 5
- 1
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [
"dtype",
"grad",
"item",
"name",
"ndim",
"numpy",
"qparams",
@@ -152,6 +151,11 @@ class module_tracer:
return self._active_scopes[-1]
return None

def top_scope(self):
if self._active_scopes:
return self._active_scopes[0]
return None


class NotExist:
pass


+ 23
- 0
imperative/python/megengine/traced_module/traced_module.py View File

@@ -180,6 +180,25 @@ def _tensor_to_node(tensors):
return nodes


def _name_setter(node: Node, new_name: str):
surgery_mode = _set_graph_surgery_mode(False)
graph = active_module_tracer().current_scope()

if node.top_graph is not None:
top_graph = active_module_tracer().top_scope()
if node is top_graph._namespace.used_names.get(node._name, None):
graph = top_graph
else:
graph = node.top_graph

assert (
graph._namespace.used_names.get(new_name, None) is None
), "The name(%s) is already in use. Please try a different one again." % (new_name)
graph._namespace.unassociate_name_with_obj(node)
node._name = graph._namespace.create_unique_name(new_name, node)
_set_graph_surgery_mode(surgery_mode)


def _wrap_method_to_tensor_node():
def _any_method(name, func):
def _any(*args, **kwargs):
@@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node():
else:
patch.set_func(_any_method(method, patch.origin_fn))
tensor_method_patch.append(patch)

patch = PatchedFn(Node, "name")
patch.set_func(property(patch.origin_fn.fget, _name_setter))
tensor_method_patch.append(patch)
return tensor_method_patch




+ 27
- 0
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -377,6 +377,33 @@ def test_set_node_name():
rename("output")
np.testing.assert_equal(str(graph.outputs[0]), "output")

def add_1(x):
x = x + 1
x.name = "func_add_1"
return x

class ModuleAdd_3(M.Module):
def forward(self, x):
x = x + 1
x.name = "module_add_1"
x = x + 2
return x

setattr(traced_module, "add_3", ModuleAdd_3())

self = graph.inputs[0]
with graph.insert_exprs():
x = output_node + 1
x.name = "_add_1"
x = add_1(x)
x = self.add_3(x)
graph.replace_node({output_node: x})
graph.compile()

assert "_add_1" in graph._namespace.used_names
assert "func_add_1" in graph._namespace.used_names
assert "module_add_1" in traced_module.add_3.graph._namespace.used_names


def test_set_graph_name():
traced_module, x, expect = _init_module()


Loading…
Cancel
Save