|
@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""): |
|
|
class _InsertExprs: |
|
|
class _InsertExprs: |
|
|
def __init__(self, graph, expr: Optional[Expr] = None): |
|
|
def __init__(self, graph, expr: Optional[Expr] = None): |
|
|
self.graph = graph |
|
|
self.graph = graph |
|
|
|
|
|
while graph.top_graph is not None: |
|
|
|
|
|
graph = graph.top_graph |
|
|
|
|
|
assert graph.inputs[0].owner._is_top |
|
|
|
|
|
self.root_graph = graph |
|
|
self.global_scope = InternalGraph( |
|
|
self.global_scope = InternalGraph( |
|
|
graph._name, graph._prefix_name, graph._module_name |
|
|
graph._name, graph._prefix_name, graph._module_name |
|
|
) |
|
|
) |
|
@@ -256,6 +260,9 @@ class _InsertExprs: |
|
|
|
|
|
|
|
|
def __enter__(self): |
|
|
def __enter__(self): |
|
|
self.use_sym_shape = set_symbolic_shape(True) |
|
|
self.use_sym_shape = set_symbolic_shape(True) |
|
|
|
|
|
node_id, expr_id = self.root_graph._total_ids |
|
|
|
|
|
Node.set_total_id(node_id) |
|
|
|
|
|
Expr.set_total_id(expr_id) |
|
|
set_module_tracing() |
|
|
set_module_tracing() |
|
|
_set_convert_node_flag(True) |
|
|
_set_convert_node_flag(True) |
|
|
assert active_module_tracer() is None |
|
|
assert active_module_tracer() is None |
|
@@ -334,10 +341,8 @@ class _InsertExprs: |
|
|
insert_index += 1 |
|
|
insert_index += 1 |
|
|
|
|
|
|
|
|
self.graph._used_names.update(self.global_scope._used_names) |
|
|
self.graph._used_names.update(self.global_scope._used_names) |
|
|
graph = self.graph |
|
|
|
|
|
while graph.top_graph is not None: |
|
|
|
|
|
graph = graph.top_graph |
|
|
|
|
|
graph.inputs[0].owner._update_ref() |
|
|
|
|
|
|
|
|
self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id()) |
|
|
|
|
|
self.root_graph.inputs[0].owner._update_ref() |
|
|
return True |
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -353,7 +358,8 @@ class InternalGraph: |
|
|
_exprs = None # type: List[Expr] |
|
|
_exprs = None # type: List[Expr] |
|
|
_inputs = None # type: List[Node] |
|
|
_inputs = None # type: List[Node] |
|
|
_outputs = None # type: List[Node] |
|
|
_outputs = None # type: List[Node] |
|
|
_top_graph = None |
|
|
|
|
|
|
|
|
_top_graph = None # type: InternalGraph |
|
|
|
|
|
_total_ids = None # type: List[int] |
|
|
|
|
|
|
|
|
def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): |
|
|
def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): |
|
|
self._exprs = [] |
|
|
self._exprs = [] |
|
@@ -704,8 +710,12 @@ class InternalGraph: |
|
|
def replace_node(self, repl_dict: Dict[Node, Node]): |
|
|
def replace_node(self, repl_dict: Dict[Node, Node]): |
|
|
while repl_dict: |
|
|
while repl_dict: |
|
|
node, repl_node = repl_dict.popitem() |
|
|
node, repl_node = repl_dict.popitem() |
|
|
|
|
|
assert type(node) == type( |
|
|
|
|
|
repl_node |
|
|
|
|
|
), "The type of {}({}) and {}({}) are not the same".format( |
|
|
|
|
|
node, type(node).__name__, repl_node, type(repl_node).__name__ |
|
|
|
|
|
) |
|
|
# check graph inputs and outputs |
|
|
# check graph inputs and outputs |
|
|
# assert node not in self.inputs, "Cannot replace inputs" |
|
|
|
|
|
for i, n in enumerate(self.outputs): |
|
|
for i, n in enumerate(self.outputs): |
|
|
if n is node: |
|
|
if n is node: |
|
|
self.outputs[i] = repl_node |
|
|
self.outputs[i] = repl_node |
|
@@ -713,7 +723,10 @@ class InternalGraph: |
|
|
# update inputs of expr in node.users |
|
|
# update inputs of expr in node.users |
|
|
graph = repl_node.top_graph |
|
|
graph = repl_node.top_graph |
|
|
assert graph is not None |
|
|
assert graph is not None |
|
|
index = graph._exprs.index(repl_node.expr) |
|
|
|
|
|
|
|
|
assert graph is self |
|
|
|
|
|
index = -1 |
|
|
|
|
|
if not isinstance(repl_node.expr, Input): |
|
|
|
|
|
index = graph._exprs.index(repl_node.expr) |
|
|
dep_exprs = self.get_dep_exprs(repl_node) |
|
|
dep_exprs = self.get_dep_exprs(repl_node) |
|
|
i = 0 |
|
|
i = 0 |
|
|
while i < len(node.users): |
|
|
while i < len(node.users): |
|
@@ -745,6 +758,13 @@ class InternalGraph: |
|
|
n.users.remove(expr) |
|
|
n.users.remove(expr) |
|
|
self._exprs.remove(expr) |
|
|
self._exprs.remove(expr) |
|
|
|
|
|
|
|
|
|
|
|
def _reset_ids(self): |
|
|
|
|
|
for total_expr_id, expr in enumerate(self.exprs()): |
|
|
|
|
|
expr._id = total_expr_id |
|
|
|
|
|
for total_node_id, node in enumerate(self.nodes()): |
|
|
|
|
|
node._id = total_node_id |
|
|
|
|
|
self._total_ids = (total_node_id + 1, total_expr_id + 1) |
|
|
|
|
|
|
|
|
def interpret(self, *inputs): |
|
|
def interpret(self, *inputs): |
|
|
node2value = {} |
|
|
node2value = {} |
|
|
end_nodes_set = set(self._end_point) |
|
|
end_nodes_set = set(self._end_point) |
|
@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
) |
|
|
) |
|
|
for _, g in self._argdef_graph_map.items(): |
|
|
for _, g in self._argdef_graph_map.items(): |
|
|
g.compile() |
|
|
g.compile() |
|
|
|
|
|
if self._is_top: |
|
|
|
|
|
g._total_ids = (Node.get_total_id(), Expr.get_total_id()) |
|
|
|
|
|
|
|
|
for k, v in self.__dict__.items(): |
|
|
for k, v in self.__dict__.items(): |
|
|
if k not in TracedModuleBuilder.__builder_attributes__: |
|
|
if k not in TracedModuleBuilder.__builder_attributes__: |
|
@@ -1247,6 +1269,8 @@ class _expr_iter: |
|
|
self.recursive = recursive |
|
|
self.recursive = recursive |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
|
|
|
for inp_node in self.graph.inputs: |
|
|
|
|
|
yield inp_node.expr |
|
|
for expr in self.graph._exprs: |
|
|
for expr in self.graph._exprs: |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
yield expr |
|
|
yield expr |
|
@@ -1262,10 +1286,10 @@ class _node_iter: |
|
|
node_ids = set() |
|
|
node_ids = set() |
|
|
for expr in graph.exprs(recursive): |
|
|
for expr in graph.exprs(recursive): |
|
|
for n in expr.inputs + expr.outputs: |
|
|
for n in expr.inputs + expr.outputs: |
|
|
if n._id in node_ids: |
|
|
|
|
|
|
|
|
if id(n) in node_ids: |
|
|
continue |
|
|
continue |
|
|
nodes.append(n) |
|
|
nodes.append(n) |
|
|
node_ids.add(n._id) |
|
|
|
|
|
|
|
|
node_ids.add(id(n)) |
|
|
self.nodes = list(sorted(nodes, key=lambda x: x._id)) |
|
|
self.nodes = list(sorted(nodes, key=lambda x: x._id)) |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
@@ -1546,6 +1570,7 @@ class TracedModule(Module): |
|
|
active_module_tracer().push_scope(new_module.graph) |
|
|
active_module_tracer().push_scope(new_module.graph) |
|
|
|
|
|
|
|
|
def _flatten_subgraph( |
|
|
def _flatten_subgraph( |
|
|
|
|
|
parent_graph: InternalGraph, |
|
|
graph: InternalGraph, |
|
|
graph: InternalGraph, |
|
|
module: Module, |
|
|
module: Module, |
|
|
call=None, |
|
|
call=None, |
|
@@ -1590,7 +1615,10 @@ class TracedModule(Module): |
|
|
if inp is call_out: |
|
|
if inp is call_out: |
|
|
expr.inputs[index] = repl_dict[out] |
|
|
expr.inputs[index] = repl_dict[out] |
|
|
repl_dict[out].users.append(expr) |
|
|
repl_dict[out].users.append(expr) |
|
|
|
|
|
|
|
|
|
|
|
if parent_graph is not None: |
|
|
|
|
|
for index, parent_out in enumerate(parent_graph._outputs): |
|
|
|
|
|
if parent_out is call_out: |
|
|
|
|
|
parent_graph._outputs[index] = repl_dict[out] |
|
|
continue |
|
|
continue |
|
|
repl_dict[out] = call.outputs[ind] |
|
|
repl_dict[out] = call.outputs[ind] |
|
|
|
|
|
|
|
@@ -1622,6 +1650,7 @@ class TracedModule(Module): |
|
|
) |
|
|
) |
|
|
exprs.extend( |
|
|
exprs.extend( |
|
|
_flatten_subgraph( |
|
|
_flatten_subgraph( |
|
|
|
|
|
graph, |
|
|
expr_graph, |
|
|
expr_graph, |
|
|
obj, |
|
|
obj, |
|
|
expr, |
|
|
expr, |
|
@@ -1643,19 +1672,10 @@ class TracedModule(Module): |
|
|
i.users.remove(call) |
|
|
i.users.remove(call) |
|
|
return exprs |
|
|
return exprs |
|
|
|
|
|
|
|
|
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) |
|
|
|
|
|
|
|
|
new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module) |
|
|
new_module.graph.compile() |
|
|
new_module.graph.compile() |
|
|
set_active_module_tracer(None) |
|
|
set_active_module_tracer(None) |
|
|
for _id, expr in enumerate(new_module.graph._exprs): |
|
|
|
|
|
expr._id = _id |
|
|
|
|
|
total_node_id = 0 |
|
|
|
|
|
for i in new_module.graph._inputs: |
|
|
|
|
|
i._id = total_node_id |
|
|
|
|
|
total_node_id += 1 |
|
|
|
|
|
for expr in new_module.graph._exprs: |
|
|
|
|
|
for o in expr.outputs: |
|
|
|
|
|
o._id = total_node_id |
|
|
|
|
|
total_node_id += 1 |
|
|
|
|
|
|
|
|
new_module.graph._reset_ids() |
|
|
return new_module |
|
|
return new_module |
|
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
def __getstate__(self): |
|
@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: |
|
|
set_active_module_tracer( |
|
|
set_active_module_tracer( |
|
|
module_tracer(_wrapped_function, _init_id2name(mod, "self")) |
|
|
module_tracer(_wrapped_function, _init_id2name(mod, "self")) |
|
|
) |
|
|
) |
|
|
|
|
|
for cls in [Expr, Node]: |
|
|
|
|
|
cls.set_total_id(0) |
|
|
with active_module_tracer().patcher: |
|
|
with active_module_tracer().patcher: |
|
|
global_scope = InternalGraph(name="") |
|
|
global_scope = InternalGraph(name="") |
|
|
active_module_tracer().push_scope(global_scope) |
|
|
active_module_tracer().push_scope(global_scope) |
|
@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: |
|
|
) |
|
|
) |
|
|
builder(*args, **kwargs) |
|
|
builder(*args, **kwargs) |
|
|
active_module_tracer().pop_scope() |
|
|
active_module_tracer().pop_scope() |
|
|
return builder.build() |
|
|
|
|
|
|
|
|
traced_mod = builder.build() |
|
|
|
|
|
traced_mod.graph._reset_ids() |
|
|
|
|
|
return traced_mod |
|
|
finally: |
|
|
finally: |
|
|
set_symbolic_shape(use_sym_shape) |
|
|
set_symbolic_shape(use_sym_shape) |
|
|
set_active_module_tracer(None) |
|
|
set_active_module_tracer(None) |
|
|