|
|
@@ -544,11 +544,11 @@ class InternalGraph: |
|
|
|
graph = self.top_graph |
|
|
|
assert graph is not None or mod._is_top, "The parent graph cannot be None." |
|
|
|
if graph is not None: |
|
|
|
assert new_name not in self._namespace.used_names, ( |
|
|
|
assert graph._namespace.used_names.get(new_name, None) is None, ( |
|
|
|
"The name(%s) is already in use. Please try a different one again." |
|
|
|
% (new_name) |
|
|
|
) |
|
|
|
new_name = self._namespace.create_unique_name(new_name, self) |
|
|
|
new_name = graph._namespace.create_unique_name(new_name, self) |
|
|
|
self._name = new_name |
|
|
|
|
|
|
|
@property |
|
|
@@ -1032,21 +1032,33 @@ class InternalGraph: |
|
|
|
n.inputs[idx] = repl_node |
|
|
|
|
|
|
|
def _merge_getattr_expr(self): |
|
|
|
getattr_nodes_map = dict() |
|
|
|
for expr in self._exprs: |
|
|
|
if not isinstance(expr, GetAttr): |
|
|
|
continue |
|
|
|
attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname) |
|
|
|
assert attr_name, '"{}" is not a prefix of "{}"'.format( |
|
|
|
self.qualname, expr.outputs[0].qualname |
|
|
|
) |
|
|
|
if attr_name in getattr_nodes_map: |
|
|
|
base_node = getattr_nodes_map[attr_name] |
|
|
|
getattr_nodes_map = dict() # Dcit[(Node, str), Node] |
|
|
|
node_to_attrname = dict() # Dict[Node, (Node, Str)] |
|
|
|
for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs): |
|
|
|
base_node, attr_name = expr.inputs[0], expr.name |
|
|
|
if expr.inputs[0] in node_to_attrname: |
|
|
|
base_node, base_name = node_to_attrname[expr.inputs[0]] |
|
|
|
attr_name = "{}.{}".format(base_name, expr.name) |
|
|
|
|
|
|
|
if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name: |
|
|
|
expected_qualname = base_node.qualname + "." + attr_name |
|
|
|
logger.warning( |
|
|
|
"{}.qualname expects {}, got {} actually. You can re-trace this " |
|
|
|
"TracedModel to make the name correct.".format( |
|
|
|
expr.outputs[0], expected_qualname, expr.outputs[0].qualname |
|
|
|
) |
|
|
|
) |
|
|
|
expr.outputs[0]._qualname = expected_qualname |
|
|
|
|
|
|
|
key = (base_node, attr_name) |
|
|
|
node_to_attrname[expr.outputs[0]] = key |
|
|
|
if key in getattr_nodes_map: |
|
|
|
existed_node = getattr_nodes_map[key] |
|
|
|
repl_node = expr.outputs[0] |
|
|
|
for expr in repl_node.users: |
|
|
|
base_node.users.append(expr) |
|
|
|
existed_node.users.append(expr) |
|
|
|
idx = expr.inputs.index(repl_node) |
|
|
|
expr.inputs[idx] = base_node |
|
|
|
expr.inputs[idx] = existed_node |
|
|
|
repl_node.users = [] |
|
|
|
else: |
|
|
|
if attr_name != expr.name: |
|
|
@@ -1054,7 +1066,7 @@ class InternalGraph: |
|
|
|
expr.inputs[0].users.remove(expr) |
|
|
|
self.inputs[0].users.append(expr) |
|
|
|
expr.inputs[0] = self.inputs[0] |
|
|
|
getattr_nodes_map[attr_name] = expr.outputs[0] |
|
|
|
getattr_nodes_map[key] = expr.outputs[0] |
|
|
|
|
|
|
|
def compile(self): |
|
|
|
r"""Delete unused expr.""" |
|
|
|