|
@@ -68,6 +68,7 @@ from .expr import ( |
|
|
is_call_tensor_method, |
|
|
is_call_tensor_method, |
|
|
is_constant, |
|
|
is_constant, |
|
|
is_getattr, |
|
|
is_getattr, |
|
|
|
|
|
is_input, |
|
|
) |
|
|
) |
|
|
from .fake_quant import FakeQuantize as TM_FakeQuant |
|
|
from .fake_quant import FakeQuantize as TM_FakeQuant |
|
|
from .module_tracer import ( |
|
|
from .module_tracer import ( |
|
@@ -342,13 +343,19 @@ class NameSpace: |
|
|
self.qualname = qualname |
|
|
self.qualname = qualname |
|
|
self._used_names = {} |
|
|
self._used_names = {} |
|
|
|
|
|
|
|
|
def create_unique_name(self, name: str) -> str: |
|
|
|
|
|
|
|
|
def create_unique_name(self, name: str, node: Any = None) -> str: |
|
|
assert isinstance(name, str), "The name must be a string" |
|
|
assert isinstance(name, str), "The name must be a string" |
|
|
|
|
|
|
|
|
|
|
|
if name in self._used_names and self._used_names[name] is node: |
|
|
|
|
|
return name |
|
|
|
|
|
|
|
|
name = re.sub("[^0-9a-zA-Z_]+", "_", name) |
|
|
name = re.sub("[^0-9a-zA-Z_]+", "_", name) |
|
|
if name[0].isdigit(): |
|
|
if name[0].isdigit(): |
|
|
name = "_{}".format(name) |
|
|
name = "_{}".format(name) |
|
|
|
|
|
|
|
|
while name in self._used_names or _is_builtin_name(name): |
|
|
|
|
|
|
|
|
while ( |
|
|
|
|
|
name in self._used_names and self._used_names[name] is not None |
|
|
|
|
|
) or _is_builtin_name(name): |
|
|
match = re.match(r"(.*)_(\d+)$", name) |
|
|
match = re.match(r"(.*)_(\d+)$", name) |
|
|
if match is None: |
|
|
if match is None: |
|
|
name = name + "_1" |
|
|
name = name + "_1" |
|
@@ -357,6 +364,10 @@ class NameSpace: |
|
|
name = "{}_{}".format(base, int(num) + 1) |
|
|
name = "{}_{}".format(base, int(num) + 1) |
|
|
|
|
|
|
|
|
self._used_names.setdefault(name) |
|
|
self._used_names.setdefault(name) |
|
|
|
|
|
|
|
|
|
|
|
if node is not None: |
|
|
|
|
|
self.associate_name_with_obj(name, node) |
|
|
|
|
|
|
|
|
return name |
|
|
return name |
|
|
|
|
|
|
|
|
def auto_naming_for_outputs(self, expr: Expr): |
|
|
def auto_naming_for_outputs(self, expr: Expr): |
|
@@ -384,7 +395,7 @@ class NameSpace: |
|
|
qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) |
|
|
qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) |
|
|
name = get_suffix_name(self.qualname, qualname) |
|
|
name = get_suffix_name(self.qualname, qualname) |
|
|
_add_suffix = lambda x: x |
|
|
_add_suffix = lambda x: x |
|
|
elif is_constant(expr): |
|
|
|
|
|
|
|
|
elif is_constant(expr) or is_input(expr): |
|
|
name = ( |
|
|
name = ( |
|
|
expr.name if expr.name else "const_" + type(expr.value).__name__.lower() |
|
|
expr.name if expr.name else "const_" + type(expr.value).__name__.lower() |
|
|
) |
|
|
) |
|
@@ -392,16 +403,25 @@ class NameSpace: |
|
|
_add_suffix = lambda x: x |
|
|
_add_suffix = lambda x: x |
|
|
|
|
|
|
|
|
for node in expr.outputs: |
|
|
for node in expr.outputs: |
|
|
if node._name == "" or node._name in self.used_names: |
|
|
|
|
|
assert _add_suffix(name) == name or isinstance(node, TensorNode) |
|
|
|
|
|
node._name = self.create_unique_name(_add_suffix(name)) |
|
|
|
|
|
|
|
|
cur_name = node._name if node._name else _add_suffix(name) |
|
|
|
|
|
node._name = self.create_unique_name(cur_name, node) |
|
|
if node._qualname == "": |
|
|
if node._qualname == "": |
|
|
node._qualname = qualname |
|
|
node._qualname = qualname |
|
|
assert get_suffix_name(self.qualname, qualname) |
|
|
|
|
|
|
|
|
assert get_suffix_name(self.qualname, qualname) is not None |
|
|
|
|
|
|
|
|
def merge(self, other: "NameSpace"): |
|
|
def merge(self, other: "NameSpace"): |
|
|
self._used_names.update(other.used_names) |
|
|
self._used_names.update(other.used_names) |
|
|
|
|
|
|
|
|
|
|
|
def associate_name_with_obj(self, name: str, node: Node): |
|
|
|
|
|
assert name in self.used_names |
|
|
|
|
|
assert self.used_names[name] is None, "The name(%s) is already in use" % (name) |
|
|
|
|
|
self._used_names[name] = node |
|
|
|
|
|
|
|
|
|
|
|
def unassociate_name_with_obj(self, node: Node): |
|
|
|
|
|
assert node.name in self.used_names |
|
|
|
|
|
assert self.used_names[node.name] is node |
|
|
|
|
|
self._used_names[node.name] = None |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def used_names(self): |
|
|
def used_names(self): |
|
|
return self._used_names |
|
|
return self._used_names |
|
@@ -487,7 +507,7 @@ class InternalGraph: |
|
|
"The name(%s) is already in use. Please try a different one again." |
|
|
"The name(%s) is already in use. Please try a different one again." |
|
|
% (new_name) |
|
|
% (new_name) |
|
|
) |
|
|
) |
|
|
new_name = self._namespace.create_unique_name(new_name) |
|
|
|
|
|
|
|
|
new_name = self._namespace.create_unique_name(new_name, self) |
|
|
self._name = new_name |
|
|
self._name = new_name |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
@@ -726,6 +746,7 @@ class InternalGraph: |
|
|
node = Input( |
|
|
node = Input( |
|
|
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) |
|
|
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) |
|
|
).outputs[0] |
|
|
).outputs[0] |
|
|
|
|
|
self._namespace.associate_name_with_obj(node.name, node) |
|
|
node.shape = val.shape |
|
|
node.shape = val.shape |
|
|
node.dtype = val.dtype |
|
|
node.dtype = val.dtype |
|
|
return node |
|
|
return node |
|
@@ -764,9 +785,11 @@ class InternalGraph: |
|
|
assert moudle._is_top, "add_input_node only supports top graph" |
|
|
assert moudle._is_top, "add_input_node only supports top graph" |
|
|
|
|
|
|
|
|
def create_node(name=None): |
|
|
def create_node(name=None): |
|
|
|
|
|
name = self._namespace.create_unique_name(name) |
|
|
node = Input( |
|
|
node = Input( |
|
|
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) |
|
|
type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) |
|
|
).outputs[0] |
|
|
).outputs[0] |
|
|
|
|
|
self._namespace.associate_name_with_obj(node.name, node) |
|
|
node.shape = shape |
|
|
node.shape = shape |
|
|
node.dtype = dtype |
|
|
node.dtype = dtype |
|
|
return node |
|
|
return node |
|
@@ -774,7 +797,7 @@ class InternalGraph: |
|
|
org_argdef = list(moudle.argdef_graph_map.keys())[0] |
|
|
org_argdef = list(moudle.argdef_graph_map.keys())[0] |
|
|
|
|
|
|
|
|
args, kwargs = org_argdef.unflatten(self._inputs) |
|
|
args, kwargs = org_argdef.unflatten(self._inputs) |
|
|
formal_inp_node = create_node(self._namespace.create_unique_name(name)) |
|
|
|
|
|
|
|
|
formal_inp_node = create_node(name) |
|
|
inputs, tree_def = tree_flatten( |
|
|
inputs, tree_def = tree_flatten( |
|
|
((*args, formal_inp_node), kwargs), |
|
|
((*args, formal_inp_node), kwargs), |
|
|
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), |
|
|
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), |
|
@@ -1006,6 +1029,8 @@ class InternalGraph: |
|
|
for n in expr.inputs: |
|
|
for n in expr.inputs: |
|
|
n.users.remove(expr) |
|
|
n.users.remove(expr) |
|
|
self._exprs.remove(expr) |
|
|
self._exprs.remove(expr) |
|
|
|
|
|
for n in expr.outputs: |
|
|
|
|
|
self._namespace.unassociate_name_with_obj(n) |
|
|
|
|
|
|
|
|
def _reset_ids(self): |
|
|
def _reset_ids(self): |
|
|
for total_expr_id, expr in enumerate(self.exprs()): |
|
|
for total_expr_id, expr in enumerate(self.exprs()): |
|
@@ -1014,6 +1039,11 @@ class InternalGraph: |
|
|
node._id = total_node_id |
|
|
node._id = total_node_id |
|
|
self._total_ids = (total_node_id + 1, total_expr_id + 1) |
|
|
self._total_ids = (total_node_id + 1, total_expr_id + 1) |
|
|
|
|
|
|
|
|
|
|
|
def _re_associate_name(self): |
|
|
|
|
|
self._namespace.used_names.clear() |
|
|
|
|
|
for node in self.nodes(False): |
|
|
|
|
|
node._name = self._namespace.create_unique_name(node.name, node) |
|
|
|
|
|
|
|
|
def interpret(self, *inputs): |
|
|
def interpret(self, *inputs): |
|
|
node2value = {} |
|
|
node2value = {} |
|
|
end_nodes_set = set(self._end_point) |
|
|
end_nodes_set = set(self._end_point) |
|
@@ -1108,6 +1138,8 @@ class InternalGraph: |
|
|
if n._qualname: |
|
|
if n._qualname: |
|
|
qualname = "{}.{}".format(qualname, n._qualname) |
|
|
qualname = "{}.{}".format(qualname, n._qualname) |
|
|
n._qualname = qualname |
|
|
n._qualname = qualname |
|
|
|
|
|
self._namespace = NameSpace(self._name, self._qualname) |
|
|
|
|
|
self._re_associate_name() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_meth_name(obj, func): |
|
|
def _get_meth_name(obj, func): |
|
@@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
continue |
|
|
continue |
|
|
for g in mod.argdef_graph_map.values(): |
|
|
for g in mod.argdef_graph_map.values(): |
|
|
replace_qualname(g) |
|
|
replace_qualname(g) |
|
|
|
|
|
g._namespace.qualname = g.qualname |
|
|
for n in g.nodes(False): |
|
|
for n in g.nodes(False): |
|
|
replace_qualname(n) |
|
|
replace_qualname(n) |
|
|
else: |
|
|
else: |
|
@@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
name=parent_graph._namespace.create_unique_name(module_qualname), |
|
|
name=parent_graph._namespace.create_unique_name(module_qualname), |
|
|
qualname=module_qualname, |
|
|
qualname=module_qualname, |
|
|
) |
|
|
) |
|
|
|
|
|
parent_graph._namespace.associate_name_with_obj(self._body.name, self._body) |
|
|
active_module_tracer().push_scope(self._body) |
|
|
active_module_tracer().push_scope(self._body) |
|
|
# rebind self to new input node |
|
|
# rebind self to new input node |
|
|
|
|
|
|
|
@@ -1552,6 +1586,7 @@ class _expr_iter: |
|
|
def __init__(self, graph: InternalGraph, recursive: bool = True): |
|
|
def __init__(self, graph: InternalGraph, recursive: bool = True): |
|
|
self.graph = graph |
|
|
self.graph = graph |
|
|
self.recursive = recursive |
|
|
self.recursive = recursive |
|
|
|
|
|
self._visited_graph = set() |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
for inp_node in self.graph.inputs: |
|
|
for inp_node in self.graph.inputs: |
|
@@ -1559,8 +1594,13 @@ class _expr_iter: |
|
|
for expr in self.graph._exprs: |
|
|
for expr in self.graph._exprs: |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
yield expr |
|
|
yield expr |
|
|
if self.recursive and expr.graph is not None: |
|
|
|
|
|
|
|
|
if ( |
|
|
|
|
|
self.recursive |
|
|
|
|
|
and expr.graph is not None |
|
|
|
|
|
and id(expr.graph) not in self._visited_graph |
|
|
|
|
|
): |
|
|
yield from expr.graph.exprs(self.recursive) |
|
|
yield from expr.graph.exprs(self.recursive) |
|
|
|
|
|
self._visited_graph.add(id(expr.graph)) |
|
|
else: |
|
|
else: |
|
|
yield expr |
|
|
yield expr |
|
|
|
|
|
|
|
@@ -1570,12 +1610,11 @@ class _node_iter: |
|
|
nodes = [] |
|
|
nodes = [] |
|
|
node_ids = set() |
|
|
node_ids = set() |
|
|
for expr in graph.exprs(recursive): |
|
|
for expr in graph.exprs(recursive): |
|
|
for n in expr.inputs + expr.outputs: |
|
|
|
|
|
if id(n) in node_ids: |
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
for n in expr.outputs: |
|
|
|
|
|
assert id(n) not in node_ids |
|
|
nodes.append(n) |
|
|
nodes.append(n) |
|
|
node_ids.add(id(n)) |
|
|
node_ids.add(id(n)) |
|
|
self.nodes = list(sorted(nodes, key=lambda x: x._id)) |
|
|
|
|
|
|
|
|
self.nodes = nodes |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
for node in self.nodes: |
|
|
for node in self.nodes: |
|
@@ -2076,10 +2115,12 @@ class TracedModule(Module): |
|
|
|
|
|
|
|
|
if parent_graph is not None: |
|
|
if parent_graph is not None: |
|
|
for node in expr.outputs: |
|
|
for node in expr.outputs: |
|
|
if node in rename_blacklist: |
|
|
|
|
|
continue |
|
|
|
|
|
name = "{}_{}".format(prefix_name, node._name) |
|
|
|
|
|
node._name = parent_graph._namespace.create_unique_name(name) |
|
|
|
|
|
|
|
|
name = node._name |
|
|
|
|
|
if node not in rename_blacklist: |
|
|
|
|
|
name = "{}_{}".format(prefix_name, name) |
|
|
|
|
|
node._name = parent_graph._namespace.create_unique_name( |
|
|
|
|
|
name, node |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
exprs.append(expr) |
|
|
exprs.append(expr) |
|
|
|
|
|
|
|
@@ -2092,6 +2133,7 @@ class TracedModule(Module): |
|
|
new_module.graph._exprs = _flatten_subgraph( |
|
|
new_module.graph._exprs = _flatten_subgraph( |
|
|
None, new_module.graph, None, new_module |
|
|
None, new_module.graph, None, new_module |
|
|
) |
|
|
) |
|
|
|
|
|
new_module.graph._re_associate_name() |
|
|
new_module.graph.compile() |
|
|
new_module.graph.compile() |
|
|
new_module.graph._reset_ids() |
|
|
new_module.graph._reset_ids() |
|
|
return new_module |
|
|
return new_module |
|
|