|
@@ -9,12 +9,10 @@ |
|
|
import collections |
|
|
import collections |
|
|
import copy |
|
|
import copy |
|
|
import functools |
|
|
import functools |
|
|
|
|
|
import inspect |
|
|
import weakref |
|
|
import weakref |
|
|
from inspect import getmembers, isclass, ismethod |
|
|
from inspect import getmembers, isclass, ismethod |
|
|
from typing import Callable, Dict, Iterable, List, Sequence, Type |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from numpy.lib.arraysetops import isin |
|
|
|
|
|
|
|
|
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union |
|
|
|
|
|
|
|
|
from ... import functional as F |
|
|
from ... import functional as F |
|
|
from ... import get_logger |
|
|
from ... import get_logger |
|
@@ -43,9 +41,9 @@ logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _leaf_type(node): |
|
|
def _leaf_type(node): |
|
|
if isinstance(node, RawTensor): |
|
|
|
|
|
|
|
|
if isinstance(node, (RawTensor, TensorNode)): |
|
|
return (Tensor, TensorNode) |
|
|
return (Tensor, TensorNode) |
|
|
elif isinstance(node, (NodeMixin, Module)): |
|
|
|
|
|
|
|
|
elif isinstance(node, (NodeMixin, Module, ModuleNode)): |
|
|
return (Module, ModuleNode, NodeMixin) |
|
|
return (Module, ModuleNode, NodeMixin) |
|
|
else: |
|
|
else: |
|
|
return type(node) |
|
|
return type(node) |
|
@@ -64,6 +62,50 @@ def _is_const_leaf(node): |
|
|
return True |
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_tensors(tensors: Tensor, nodes: TensorNode): |
|
|
|
|
|
inp_tensors = copy.deepcopy(tensors) |
|
|
|
|
|
inp_tensors, inp_def_v = tree_flatten( |
|
|
|
|
|
inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
|
|
) |
|
|
|
|
|
inp_nodes, inp_def_n = tree_flatten( |
|
|
|
|
|
nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
|
|
) |
|
|
|
|
|
for v, n in zip(inp_tensors, inp_nodes): |
|
|
|
|
|
if isinstance(n, TensorNode) and isinstance(v, Tensor): |
|
|
|
|
|
NodeMixin.wrap_safe(v, n) |
|
|
|
|
|
return inp_def_v.unflatten(inp_tensors) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _InsertExprs: |
|
|
|
|
|
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True): |
|
|
|
|
|
self.graph = graph |
|
|
|
|
|
self.global_scope = InternalGraph() |
|
|
|
|
|
self.expr = expr |
|
|
|
|
|
self.after = after |
|
|
|
|
|
|
|
|
|
|
|
def __enter__(self): |
|
|
|
|
|
self.use_sym_shape = set_symbolic_shape(True) |
|
|
|
|
|
set_module_tracing() |
|
|
|
|
|
assert active_module_tracer() is None |
|
|
|
|
|
set_active_module_tracer(module_tracer(_wrapped_function)) |
|
|
|
|
|
active_module_tracer().patcher.__enter__() |
|
|
|
|
|
active_module_tracer().push_scope(self.global_scope) |
|
|
|
|
|
|
|
|
|
|
|
def __exit__(self, ty, va, tr): |
|
|
|
|
|
set_symbolic_shape(self.use_sym_shape) |
|
|
|
|
|
unset_module_tracing() |
|
|
|
|
|
active_module_tracer().patcher.__exit__(ty, va, tr) |
|
|
|
|
|
set_active_module_tracer(None) |
|
|
|
|
|
index = len(self.graph._exprs) if self.after else 0 |
|
|
|
|
|
if self.expr is not None: |
|
|
|
|
|
index = self.graph._exprs.index(self.expr) |
|
|
|
|
|
if self.after: |
|
|
|
|
|
index += 1 |
|
|
|
|
|
for expr in self.global_scope._exprs: |
|
|
|
|
|
self.graph._exprs.insert(index, expr) |
|
|
|
|
|
index += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InternalGraph: |
|
|
class InternalGraph: |
|
|
""" |
|
|
""" |
|
|
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. |
|
|
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. |
|
@@ -95,14 +137,28 @@ class InternalGraph: |
|
|
return self._outputs |
|
|
return self._outputs |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def exprs(self): |
|
|
|
|
|
|
|
|
def expr_filter(self): |
|
|
return ExprFilter(_expr_iter(self)) |
|
|
return ExprFilter(_expr_iter(self)) |
|
|
|
|
|
|
|
|
def get_call_function(self, func: Callable = None): |
|
|
|
|
|
return self.exprs.call_function(func) |
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def node_filter(self): |
|
|
|
|
|
return NodeFilter(_node_iter(self)) |
|
|
|
|
|
|
|
|
|
|
|
def get_function_by_type(self, func: Callable = None): |
|
|
|
|
|
return self.expr_filter.call_function(func) |
|
|
|
|
|
|
|
|
|
|
|
def get_method_by_type(self, method: str = None): |
|
|
|
|
|
return self.expr_filter.call_method(method) |
|
|
|
|
|
|
|
|
def get_call_method(self, method: str = None): |
|
|
|
|
|
return self.exprs.call_method(method) |
|
|
|
|
|
|
|
|
def get_expr_by_id(self, expr_id: List[int] = None): |
|
|
|
|
|
return self.expr_filter.expr_id(expr_id) |
|
|
|
|
|
|
|
|
|
|
|
def get_module_by_type(self, module_cls: Module): |
|
|
|
|
|
assert issubclass(module_cls, Module) |
|
|
|
|
|
return self.node_filter.type(module_cls, ModuleNode) |
|
|
|
|
|
|
|
|
|
|
|
def get_node_by_id(self, node_id: List[int] = None): |
|
|
|
|
|
return self.node_filter.node_id(node_id) |
|
|
|
|
|
|
|
|
def add_input(self, i): |
|
|
def add_input(self, i): |
|
|
self._inputs.append(i) |
|
|
self._inputs.append(i) |
|
@@ -124,7 +180,6 @@ class InternalGraph: |
|
|
for idx, o in enumerate(self._outputs): |
|
|
for idx, o in enumerate(self._outputs): |
|
|
if o in repl_dict: |
|
|
if o in repl_dict: |
|
|
self._outputs[idx] = repl_dict[o] |
|
|
self._outputs[idx] = repl_dict[o] |
|
|
self._outputs[idx].expr = node.expr |
|
|
|
|
|
|
|
|
|
|
|
for expr in self._exprs: |
|
|
for expr in self._exprs: |
|
|
|
|
|
|
|
@@ -135,83 +190,283 @@ class InternalGraph: |
|
|
for idx, o in enumerate(expr.outputs): |
|
|
for idx, o in enumerate(expr.outputs): |
|
|
if o in repl_dict: |
|
|
if o in repl_dict: |
|
|
expr.outputs[idx] = repl_dict[o] |
|
|
expr.outputs[idx] = repl_dict[o] |
|
|
|
|
|
expr.outputs[idx].expr = expr |
|
|
|
|
|
|
|
|
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: |
|
|
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: |
|
|
if not isinstance(nodes, Sequence): |
|
|
if not isinstance(nodes, Sequence): |
|
|
nodes = (nodes,) |
|
|
nodes = (nodes,) |
|
|
ret = list() |
|
|
ret = list() |
|
|
queue = list(nodes) |
|
|
queue = list(nodes) |
|
|
|
|
|
visited_queue = list() |
|
|
while queue: |
|
|
while queue: |
|
|
node = queue.pop() |
|
|
node = queue.pop() |
|
|
|
|
|
visited_queue.append(node) |
|
|
|
|
|
|
|
|
expr = node.expr |
|
|
expr = node.expr |
|
|
|
|
|
|
|
|
if expr not in ret: |
|
|
if expr not in ret: |
|
|
ret.append(expr) |
|
|
ret.append(expr) |
|
|
|
|
|
|
|
|
for i in expr.inputs: |
|
|
for i in expr.inputs: |
|
|
if i not in queue: |
|
|
|
|
|
|
|
|
if i not in queue and i not in visited_queue: |
|
|
queue.append(i) |
|
|
queue.append(i) |
|
|
return ret |
|
|
return ret |
|
|
|
|
|
|
|
|
def insert_call_function(self, func: Callable, nodes: Sequence[Node]): |
|
|
|
|
|
if not isinstance(nodes, Sequence): |
|
|
|
|
|
nodes = [nodes] |
|
|
|
|
|
assert isinstance(func, Callable) |
|
|
|
|
|
for i in nodes: |
|
|
|
|
|
assert isinstance( |
|
|
|
|
|
i, TensorNode |
|
|
|
|
|
), "CallFunction only accept TensorNode as inputs" |
|
|
|
|
|
|
|
|
def reset_inputs(self, *args, **kwargs): |
|
|
|
|
|
forma_mnode = self.inputs[0] |
|
|
|
|
|
actual_mnodes = forma_mnode.actual_mnode |
|
|
|
|
|
call_nodes = [] |
|
|
|
|
|
for n in actual_mnodes: |
|
|
|
|
|
for c_expr in n.users: |
|
|
|
|
|
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": |
|
|
|
|
|
call_nodes.append((c_expr, n)) |
|
|
|
|
|
|
|
|
expr = CallFunction(func) |
|
|
|
|
|
expr.inputs = nodes |
|
|
|
|
|
|
|
|
moudle = forma_mnode.owner |
|
|
|
|
|
assert moudle._is_top, "reset_inputs only support the top-level graph" |
|
|
|
|
|
|
|
|
|
|
|
inputs, tree_def = tree_flatten( |
|
|
|
|
|
((moudle, *args), kwargs), |
|
|
|
|
|
leaf_type=_leaf_type, |
|
|
|
|
|
is_const_leaf=_is_const_leaf, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
for i in nodes: |
|
|
|
|
|
i.users.append(expr) |
|
|
|
|
|
|
|
|
def create_node(val: Tensor): |
|
|
|
|
|
node = Input(type=TensorNode).outputs[0] |
|
|
|
|
|
node.shape = val.shape |
|
|
|
|
|
node.dtype = val.dtype |
|
|
|
|
|
return node |
|
|
|
|
|
|
|
|
idx = max(self._exprs.index(i.expr) for i in nodes) + 1 |
|
|
|
|
|
self._exprs.insert(idx, expr) |
|
|
|
|
|
|
|
|
formal_node_inputs = [ |
|
|
|
|
|
forma_mnode, |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
org_argdef = list(moudle.argdef_graph_map.keys())[0] |
|
|
|
|
|
if call_nodes: |
|
|
|
|
|
org_argdef = call_nodes[0][0].arg_def |
|
|
|
|
|
|
|
|
|
|
|
for v in inputs[1:]: |
|
|
|
|
|
assert isinstance(v, RawTensor) |
|
|
|
|
|
formal_node_inputs.append(create_node(v)) |
|
|
|
|
|
|
|
|
|
|
|
actual_nodes = [] |
|
|
|
|
|
for e, n in call_nodes: |
|
|
|
|
|
e.arg_def = tree_def |
|
|
|
|
|
actual_node_inputs = [ |
|
|
|
|
|
n, |
|
|
|
|
|
] |
|
|
|
|
|
for v in inputs[1:]: |
|
|
|
|
|
actual_node_inputs.append(create_node(v)) |
|
|
|
|
|
|
|
|
|
|
|
for org_n in e.inputs: |
|
|
|
|
|
org_n.users.pop(e) |
|
|
|
|
|
|
|
|
|
|
|
e.inputs[:] = actual_node_inputs |
|
|
|
|
|
e.const_val = [] |
|
|
|
|
|
actual_nodes.append(actual_node_inputs[1:]) |
|
|
|
|
|
|
|
|
|
|
|
self._inputs[:] = formal_node_inputs |
|
|
|
|
|
moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) |
|
|
|
|
|
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) |
|
|
|
|
|
|
|
|
|
|
|
# return formal_node_inputs[1:], actual_nodes |
|
|
|
|
|
return formal_node_inputs[1:] |
|
|
|
|
|
|
|
|
|
|
|
def add_input_node(self, shape, dtype="float32"): |
|
|
|
|
|
forma_mnode = self.inputs[0] |
|
|
|
|
|
actual_mnodes = forma_mnode.actual_mnode |
|
|
|
|
|
|
|
|
|
|
|
moudle = forma_mnode.owner |
|
|
|
|
|
assert moudle._is_top, "add_input_node only support the top-level graph" |
|
|
|
|
|
|
|
|
|
|
|
call_nodes = [] |
|
|
|
|
|
for n in actual_mnodes: |
|
|
|
|
|
for c_expr in n.users: |
|
|
|
|
|
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": |
|
|
|
|
|
call_nodes.append(c_expr) |
|
|
|
|
|
|
|
|
|
|
|
def create_node(is_input: bool = True): |
|
|
|
|
|
if is_input: |
|
|
|
|
|
node = Input(type=TensorNode).outputs[0] |
|
|
|
|
|
else: |
|
|
|
|
|
node = TensorNode(expr=None) |
|
|
|
|
|
node.shape = shape |
|
|
|
|
|
node.dtype = dtype |
|
|
|
|
|
return node |
|
|
|
|
|
|
|
|
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes) |
|
|
|
|
|
fake_out_val = func(*fake_inp_val) |
|
|
|
|
|
|
|
|
org_argdef = list(moudle.argdef_graph_map.keys())[0] |
|
|
|
|
|
|
|
|
def create_node(val: Tensor): |
|
|
|
|
|
|
|
|
if call_nodes: |
|
|
|
|
|
org_argdef = call_nodes[0].arg_def |
|
|
|
|
|
|
|
|
|
|
|
args, kwargs = org_argdef.unflatten(self._inputs) |
|
|
|
|
|
formal_inp_node = create_node(True) |
|
|
|
|
|
inputs, tree_def = tree_flatten( |
|
|
|
|
|
((*args, formal_inp_node), kwargs), |
|
|
|
|
|
leaf_type=_leaf_type, |
|
|
|
|
|
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), |
|
|
|
|
|
) |
|
|
|
|
|
self._inputs[:] = inputs[:] |
|
|
|
|
|
|
|
|
|
|
|
actual_inp_nodes = [] |
|
|
|
|
|
for e in call_nodes: |
|
|
|
|
|
args, kwargs = e.unflatten_args(e.inputs) |
|
|
|
|
|
args = args + (create_node(False),) |
|
|
|
|
|
inputs, tree_def = tree_flatten( |
|
|
|
|
|
(args, kwargs), |
|
|
|
|
|
leaf_type=_leaf_type, |
|
|
|
|
|
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), |
|
|
|
|
|
) |
|
|
|
|
|
e.inputs[:] = inputs[:] |
|
|
|
|
|
e.arg_def = tree_def |
|
|
|
|
|
actual_inp_nodes.append(args[-1]) |
|
|
|
|
|
|
|
|
|
|
|
moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) |
|
|
|
|
|
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) |
|
|
|
|
|
|
|
|
|
|
|
# return formal_inp_node, actual_inp_nodes |
|
|
|
|
|
return formal_inp_node |
|
|
|
|
|
|
|
|
|
|
|
def reset_outputs(self, outputs): |
|
|
|
|
|
outputs, out_def = tree_flatten( |
|
|
|
|
|
outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode), |
|
|
|
|
|
) |
|
|
|
|
|
forma_mnode = self.inputs[0] |
|
|
|
|
|
|
|
|
|
|
|
moudle = forma_mnode.owner |
|
|
|
|
|
assert moudle._is_top, "reset_outputs only support the top-level graph" |
|
|
|
|
|
|
|
|
|
|
|
actual_mnodes = forma_mnode.actual_mnode |
|
|
|
|
|
call_nodes = [] |
|
|
|
|
|
for n in actual_mnodes: |
|
|
|
|
|
for c_expr in n.users: |
|
|
|
|
|
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": |
|
|
|
|
|
call_nodes.append((c_expr)) |
|
|
|
|
|
|
|
|
|
|
|
def create_node(val: TensorNode, expr: Expr): |
|
|
node = TensorNode(expr) |
|
|
node = TensorNode(expr) |
|
|
node.shape = val.shape |
|
|
node.shape = val.shape |
|
|
node.dtype = val.dtype |
|
|
node.dtype = val.dtype |
|
|
return node |
|
|
return node |
|
|
|
|
|
|
|
|
out_nodes = list(create_node(i) for i in fake_out_val) |
|
|
|
|
|
expr.outputs = out_nodes |
|
|
|
|
|
|
|
|
tree_def = list(moudle.argdef_graph_map.keys())[0] |
|
|
|
|
|
if call_nodes: |
|
|
|
|
|
tree_def = call_nodes[0].arg_def |
|
|
|
|
|
|
|
|
return out_nodes |
|
|
|
|
|
|
|
|
actual_nodes = [] |
|
|
|
|
|
for e in call_nodes: |
|
|
|
|
|
actual_node_outputs = [] |
|
|
|
|
|
for v in outputs: |
|
|
|
|
|
actual_node_outputs.append(create_node(v, e)) |
|
|
|
|
|
e.outputs[:] = actual_node_outputs |
|
|
|
|
|
e.out_def = out_def |
|
|
|
|
|
actual_nodes.append(actual_node_outputs) |
|
|
|
|
|
|
|
|
def insert_call_method(self, target, method, args): |
|
|
|
|
|
if not isinstance(args, Sequence): |
|
|
|
|
|
args = [args] |
|
|
|
|
|
assert isinstance(target, (TensorNode, ModuleNode)) |
|
|
|
|
|
assert isinstance(method, str) |
|
|
|
|
|
for i in args: |
|
|
|
|
|
assert isinstance(i, TensorNode) |
|
|
|
|
|
|
|
|
self._outputs[:] = outputs |
|
|
|
|
|
moudle.argdef_outdef_map[tree_def] = out_def |
|
|
|
|
|
|
|
|
expr = CallMethod(method) |
|
|
|
|
|
expr.inputs = [target, *args] |
|
|
|
|
|
|
|
|
return actual_nodes |
|
|
|
|
|
|
|
|
if isinstance(target, TensorNode): |
|
|
|
|
|
fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype) |
|
|
|
|
|
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args) |
|
|
|
|
|
fake_out_val = getattr(fake_target_val, method)(fake_inp_val) |
|
|
|
|
|
|
|
|
def add_output_node(self, node: TensorNode): |
|
|
|
|
|
forma_mnode = self.inputs[0] |
|
|
|
|
|
|
|
|
def create_node(val: Tensor): |
|
|
|
|
|
node = TensorNode(expr) |
|
|
|
|
|
node.shape = val.shape |
|
|
|
|
|
node.dtype = val.dtype |
|
|
|
|
|
return node |
|
|
|
|
|
|
|
|
moudle = forma_mnode.owner |
|
|
|
|
|
assert moudle._is_top, "add_output_node only support the top-level graph" |
|
|
|
|
|
|
|
|
out_nodes = list(create_node(i) for i in fake_out_val) |
|
|
|
|
|
expr.outputs = out_nodes |
|
|
|
|
|
else: |
|
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
actual_mnodes = forma_mnode.actual_mnode |
|
|
|
|
|
call_nodes = [] |
|
|
|
|
|
|
|
|
|
|
|
for n in actual_mnodes: |
|
|
|
|
|
for c_expr in n.users: |
|
|
|
|
|
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": |
|
|
|
|
|
call_nodes.append((c_expr)) |
|
|
|
|
|
|
|
|
|
|
|
def create_node(val: TensorNode, expr: Expr): |
|
|
|
|
|
node = TensorNode(expr) |
|
|
|
|
|
node.shape = val.shape |
|
|
|
|
|
node.dtype = val.dtype |
|
|
|
|
|
return node |
|
|
|
|
|
|
|
|
|
|
|
tree_def = list(moudle.argdef_graph_map.keys())[0] |
|
|
|
|
|
if call_nodes: |
|
|
|
|
|
tree_def = call_nodes[0].arg_def |
|
|
|
|
|
|
|
|
|
|
|
org_out_def = moudle.argdef_outdef_map[tree_def] |
|
|
|
|
|
org_outs = org_out_def.unflatten(self._outputs) |
|
|
|
|
|
outputs, out_def = tree_flatten( |
|
|
|
|
|
(org_outs, node), |
|
|
|
|
|
leaf_type=_leaf_type, |
|
|
|
|
|
is_leaf=lambda x: isinstance(x, TensorNode), |
|
|
|
|
|
) |
|
|
|
|
|
self._outputs[:] = outputs |
|
|
|
|
|
|
|
|
|
|
|
actual_out_nodes = [] |
|
|
|
|
|
for e in call_nodes: |
|
|
|
|
|
actual_node = create_node(node, e) |
|
|
|
|
|
org_outs = org_out_def.unflatten(e.outputs) |
|
|
|
|
|
outputs, out_def = tree_flatten( |
|
|
|
|
|
(org_outs, actual_node), |
|
|
|
|
|
leaf_type=_leaf_type, |
|
|
|
|
|
is_leaf=lambda x: isinstance(x, TensorNode), |
|
|
|
|
|
) |
|
|
|
|
|
e.outputs[:] = outputs |
|
|
|
|
|
e.out_def = out_def |
|
|
|
|
|
actual_out_nodes.append(actual_node) |
|
|
|
|
|
|
|
|
|
|
|
moudle.argdef_outdef_map[tree_def] = out_def |
|
|
|
|
|
|
|
|
|
|
|
return actual_out_nodes |
|
|
|
|
|
|
|
|
|
|
|
def insert_function(self, func: Callable, *args, **kwargs): |
|
|
|
|
|
assert isinstance(func, Callable) |
|
|
|
|
|
|
|
|
|
|
|
inp_nodes, inp_def = tree_flatten( |
|
|
|
|
|
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
insert_idx = -1 |
|
|
|
|
|
for i in inp_nodes: |
|
|
|
|
|
if isinstance(i, TensorNode) and i.expr in self._exprs: |
|
|
|
|
|
insert_idx = max(insert_idx, self._exprs.index(i.expr)) |
|
|
|
|
|
|
|
|
|
|
|
fake_inp_val = list( |
|
|
|
|
|
F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i |
|
|
|
|
|
for i in inp_nodes |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
for v, n in zip(fake_inp_val, inp_nodes): |
|
|
|
|
|
if isinstance(n, TensorNode): |
|
|
|
|
|
NodeMixin.wrap_safe(v, n) |
|
|
|
|
|
|
|
|
|
|
|
fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val) |
|
|
|
|
|
|
|
|
|
|
|
insert_point = self.insert_exprs_before() |
|
|
|
|
|
if insert_idx != -1: |
|
|
|
|
|
insert_point = self.insert_exprs_after(self._exprs[insert_idx]) |
|
|
|
|
|
|
|
|
|
|
|
with insert_point: |
|
|
|
|
|
rst = func(*fake_args, **fake_kwargs) |
|
|
|
|
|
|
|
|
|
|
|
if rst is None: |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf) |
|
|
|
|
|
node_outputs = [] |
|
|
|
|
|
for out in outputs: |
|
|
|
|
|
assert isinstance(out, RawTensor) |
|
|
|
|
|
node_outputs.append(NodeMixin.get(out, None)) |
|
|
|
|
|
|
|
|
|
|
|
node_outputs = out_def.unflatten(node_outputs) |
|
|
|
|
|
return node_outputs |
|
|
|
|
|
|
|
|
|
|
|
def insert_exprs_after(self, expr: Optional[Expr] = None): |
|
|
|
|
|
if expr is not None: |
|
|
|
|
|
assert expr.top_graph == self, "Expr to insert after is not in graph." |
|
|
|
|
|
return _InsertExprs(self, expr, after=True) |
|
|
|
|
|
|
|
|
return out_nodes |
|
|
|
|
|
|
|
|
def insert_exprs_before(self, expr: Optional[Expr] = None): |
|
|
|
|
|
if expr is not None: |
|
|
|
|
|
assert expr.top_graph == self, "Expr to insert before is not in graph." |
|
|
|
|
|
return _InsertExprs(self, expr, after=False) |
|
|
|
|
|
|
|
|
def replace_node(self, repl_dict: Dict[Node, Node]): |
|
|
def replace_node(self, repl_dict: Dict[Node, Node]): |
|
|
while repl_dict: |
|
|
while repl_dict: |
|
@@ -246,7 +501,7 @@ class InternalGraph: |
|
|
i = 0 |
|
|
i = 0 |
|
|
while i < len(self._exprs): |
|
|
while i < len(self._exprs): |
|
|
expr = self._exprs[i] |
|
|
expr = self._exprs[i] |
|
|
if expr in dep_exprs: |
|
|
|
|
|
|
|
|
if expr in dep_exprs or expr._disable_remove: |
|
|
i += 1 |
|
|
i += 1 |
|
|
continue |
|
|
continue |
|
|
for n in expr.inputs: |
|
|
for n in expr.inputs: |
|
@@ -267,7 +522,7 @@ class InternalGraph: |
|
|
def __repr__(self): |
|
|
def __repr__(self): |
|
|
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( |
|
|
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( |
|
|
", ".join(str(i) for i in self._inputs), |
|
|
", ".join(str(i) for i in self._inputs), |
|
|
"\n\t".join(str(i) for i in self._exprs), |
|
|
|
|
|
|
|
|
"\n\t".join("{}".format(str(i)) for i in self._exprs), |
|
|
", ".join(str(i) for i in self._outputs), |
|
|
", ".join(str(i) for i in self._outputs), |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
@@ -293,7 +548,7 @@ def _wrapped_function(orig_func): |
|
|
if not NodeMixin.get(i, None): |
|
|
if not NodeMixin.get(i, None): |
|
|
if isinstance(i, (RawTensor, NodeMixin)): |
|
|
if isinstance(i, (RawTensor, NodeMixin)): |
|
|
NodeMixin.wrap_safe(i, Constant.make(i)) |
|
|
NodeMixin.wrap_safe(i, Constant.make(i)) |
|
|
meth_name = _get_meth_name(args[0], wrapped_fn) |
|
|
|
|
|
|
|
|
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None |
|
|
if meth_name: |
|
|
if meth_name: |
|
|
self = inputs[0] |
|
|
self = inputs[0] |
|
|
if meth_name == "__new__": |
|
|
if meth_name == "__new__": |
|
@@ -316,10 +571,19 @@ def _wrapped_function(orig_func): |
|
|
call_node.add_inputs(inputs) |
|
|
call_node.add_inputs(inputs) |
|
|
|
|
|
|
|
|
call_node.arg_def = tree_def |
|
|
call_node.arg_def = tree_def |
|
|
outputs = orig_func(*args, **kwargs) |
|
|
|
|
|
|
|
|
rst = orig_func(*args, **kwargs) |
|
|
|
|
|
if meth_name == "__setitem__": |
|
|
|
|
|
rst = self |
|
|
|
|
|
if rst is not None: |
|
|
|
|
|
outputs, out_def = tree_flatten( |
|
|
|
|
|
rst, leaf_type=_leaf_type, is_leaf=_is_leaf |
|
|
|
|
|
) |
|
|
|
|
|
call_node.out_def = out_def |
|
|
|
|
|
else: |
|
|
|
|
|
outputs = None |
|
|
call_node.add_outputs(outputs) |
|
|
call_node.add_outputs(outputs) |
|
|
set_module_tracing() |
|
|
set_module_tracing() |
|
|
return outputs |
|
|
|
|
|
|
|
|
return rst |
|
|
return orig_func(*args, **kwargs) |
|
|
return orig_func(*args, **kwargs) |
|
|
|
|
|
|
|
|
return wrapped_fn |
|
|
return wrapped_fn |
|
@@ -349,6 +613,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
super(TracedModuleBuilder, self).__init__() |
|
|
super(TracedModuleBuilder, self).__init__() |
|
|
self._mod = mod |
|
|
self._mod = mod |
|
|
self._body = None |
|
|
self._body = None |
|
|
|
|
|
self._is_top = is_top_module |
|
|
self._is_builtin = module_tracer.is_builtin(mod) |
|
|
self._is_builtin = module_tracer.is_builtin(mod) |
|
|
self._argdef_graph_map = {} |
|
|
self._argdef_graph_map = {} |
|
|
self._argdef_outdef_map = {} |
|
|
self._argdef_outdef_map = {} |
|
@@ -362,7 +627,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
return self._mod |
|
|
return self._mod |
|
|
else: |
|
|
else: |
|
|
traced_module = TracedModule( |
|
|
traced_module = TracedModule( |
|
|
self._argdef_graph_map, self._argdef_outdef_map |
|
|
|
|
|
|
|
|
self._is_top, self._argdef_graph_map, self._argdef_outdef_map |
|
|
) |
|
|
) |
|
|
for _, g in self._argdef_graph_map.items(): |
|
|
for _, g in self._argdef_graph_map.items(): |
|
|
g.compile() |
|
|
g.compile() |
|
@@ -408,8 +673,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
self._body = None |
|
|
self._body = None |
|
|
else: |
|
|
else: |
|
|
self_node = None |
|
|
self_node = None |
|
|
if self._body: |
|
|
|
|
|
self_node = self._body.inputs[0] |
|
|
|
|
|
|
|
|
if tree_def in self._argdef_graph_map: |
|
|
|
|
|
self_node = self._argdef_graph_map[tree_def].inputs[0] |
|
|
self._body = InternalGraph() |
|
|
self._body = InternalGraph() |
|
|
active_module_tracer().push_scope(self._body) |
|
|
active_module_tracer().push_scope(self._body) |
|
|
# rebind self to new input node |
|
|
# rebind self to new input node |
|
@@ -446,7 +711,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) |
|
|
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) |
|
|
): |
|
|
): |
|
|
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) |
|
|
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) |
|
|
|
|
|
|
|
|
|
|
|
NodeMixin.get(self, None).actual_mnode.append(orig_self) |
|
|
NodeMixin.wrap_safe(self, orig_self) |
|
|
NodeMixin.wrap_safe(self, orig_self) |
|
|
for arg, node in zip(inputs[1:], origin_inp_node): |
|
|
for arg, node in zip(inputs[1:], origin_inp_node): |
|
|
if node: |
|
|
if node: |
|
@@ -454,6 +719,7 @@ class TracedModuleBuilder(NodeMixin): |
|
|
active_module_tracer().pop_scope() |
|
|
active_module_tracer().pop_scope() |
|
|
|
|
|
|
|
|
# rebind output to outer graph |
|
|
# rebind output to outer graph |
|
|
|
|
|
callnode.out_def = out_def |
|
|
callnode.add_outputs(outputs) |
|
|
callnode.add_outputs(outputs) |
|
|
self._argdef_graph_map[callnode.arg_def] = self._body |
|
|
self._argdef_graph_map[callnode.arg_def] = self._body |
|
|
self._argdef_outdef_map[callnode.arg_def] = out_def |
|
|
self._argdef_outdef_map[callnode.arg_def] = out_def |
|
@@ -512,31 +778,44 @@ class _expr_iter: |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
yield expr |
|
|
yield expr |
|
|
if expr.graph is not None: |
|
|
if expr.graph is not None: |
|
|
yield from expr.graph.exprs |
|
|
|
|
|
|
|
|
yield from expr.graph.expr_filter |
|
|
else: |
|
|
else: |
|
|
yield expr |
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExprFilter: |
|
|
|
|
|
|
|
|
class _node_iter: |
|
|
|
|
|
def __init__(self, graph: InternalGraph) -> None: |
|
|
|
|
|
nodes = [] |
|
|
|
|
|
node_ids = set() |
|
|
|
|
|
for expr in graph.expr_filter: |
|
|
|
|
|
for n in expr.inputs + expr.outputs: |
|
|
|
|
|
if n._id in node_ids: |
|
|
|
|
|
continue |
|
|
|
|
|
nodes.append(n) |
|
|
|
|
|
node_ids.add(n._id) |
|
|
|
|
|
self.nodes = list(sorted(nodes, key=lambda x: x._id)) |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
for node in self.nodes: |
|
|
|
|
|
yield node |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseFilter: |
|
|
def __init__(self, expr_iter: Iterable): |
|
|
def __init__(self, expr_iter: Iterable): |
|
|
self._iter = expr_iter |
|
|
self._iter = expr_iter |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
return iter(self._iter) |
|
|
return iter(self._iter) |
|
|
|
|
|
|
|
|
def call_function(self, func): |
|
|
|
|
|
return ExprFilterCallFunction(self, func) |
|
|
|
|
|
|
|
|
|
|
|
def call_method(self, method): |
|
|
|
|
|
return ExprFilterCallMethod(self, method) |
|
|
|
|
|
|
|
|
|
|
|
def as_list(self): |
|
|
def as_list(self): |
|
|
return list(self) |
|
|
return list(self) |
|
|
|
|
|
|
|
|
def as_dict(self): |
|
|
def as_dict(self): |
|
|
raise NotImplementedError("need key") |
|
|
|
|
|
|
|
|
return collections.OrderedDict((i._id, i) for i in self) |
|
|
|
|
|
|
|
|
def as_unique(self): |
|
|
def as_unique(self): |
|
|
|
|
|
rst = self.as_list() |
|
|
|
|
|
assert len(rst) == 1, "{} elements found".format(len(rst)) |
|
|
(expr,) = self |
|
|
(expr,) = self |
|
|
return expr |
|
|
return expr |
|
|
|
|
|
|
|
@@ -544,17 +823,65 @@ class ExprFilter: |
|
|
return sum(1 for _ in self) |
|
|
return sum(1 for _ in self) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExprFilter(BaseFilter): |
|
|
|
|
|
def call_function(self, func): |
|
|
|
|
|
return ExprFilterCallFunction(self, func) |
|
|
|
|
|
|
|
|
|
|
|
def call_method(self, method): |
|
|
|
|
|
return ExprFilterCallMethod(self, method) |
|
|
|
|
|
|
|
|
|
|
|
def expr_id(self, expr_id: List[int]): |
|
|
|
|
|
return ExprFilterExprId(self, expr_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeFilter(BaseFilter): |
|
|
|
|
|
def type(self, owner_type, node_type): |
|
|
|
|
|
return NodeFilterType(self, owner_type, node_type) |
|
|
|
|
|
|
|
|
|
|
|
def node_id(self, node_id: List[int]): |
|
|
|
|
|
return NodeFilterNodeId(self, node_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeFilterType(NodeFilter): |
|
|
|
|
|
def __init__(self, expr_iter, owner_type, node_type): |
|
|
|
|
|
super().__init__(expr_iter) |
|
|
|
|
|
self.owner_type = owner_type |
|
|
|
|
|
self.node_type = node_type |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
for node in self._iter: |
|
|
|
|
|
if not isinstance(node, self.node_type): |
|
|
|
|
|
continue |
|
|
|
|
|
if not hasattr(node, "owner"): |
|
|
|
|
|
continue |
|
|
|
|
|
if isinstance(node.owner, self.owner_type): |
|
|
|
|
|
yield node |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeFilterNodeId(NodeFilter): |
|
|
|
|
|
def __init__(self, expr_iter, node_id: List[int]): |
|
|
|
|
|
super().__init__(expr_iter) |
|
|
|
|
|
if not isinstance(node_id, Sequence): |
|
|
|
|
|
node_id = [node_id] |
|
|
|
|
|
self.node_id = node_id |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
for node in self._iter: |
|
|
|
|
|
if node._id in self.node_id: |
|
|
|
|
|
yield node |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExprFilterCallFunction(ExprFilter): |
|
|
class ExprFilterCallFunction(ExprFilter): |
|
|
def __init__(self, expr_iter, func: Callable = None): |
|
|
def __init__(self, expr_iter, func: Callable = None): |
|
|
super().__init__(expr_iter) |
|
|
super().__init__(expr_iter) |
|
|
self.func = func |
|
|
self.func = func |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
for i in self._iter: |
|
|
|
|
|
if not isinstance(i, CallFunction): |
|
|
|
|
|
|
|
|
for expr in self._iter: |
|
|
|
|
|
if not isinstance(expr, CallFunction): |
|
|
continue |
|
|
continue |
|
|
if self.func is None or i.func == self.func: |
|
|
|
|
|
yield i |
|
|
|
|
|
|
|
|
if self.func is None or expr.func == self.func: |
|
|
|
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExprFilterCallMethod(ExprFilter): |
|
|
class ExprFilterCallMethod(ExprFilter): |
|
@@ -563,11 +890,24 @@ class ExprFilterCallMethod(ExprFilter): |
|
|
self.method = method |
|
|
self.method = method |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
for i in self._iter: |
|
|
|
|
|
if not isinstance(i, CallMethod): |
|
|
|
|
|
|
|
|
for expr in self._iter: |
|
|
|
|
|
if not isinstance(expr, CallMethod): |
|
|
continue |
|
|
continue |
|
|
if self.method is None or i.method == self.method: |
|
|
|
|
|
yield i |
|
|
|
|
|
|
|
|
if self.method is None or expr.method == self.method: |
|
|
|
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExprFilterExprId(ExprFilter): |
|
|
|
|
|
def __init__(self, expr_iter, expr_id: List[int]): |
|
|
|
|
|
super().__init__(expr_iter) |
|
|
|
|
|
if not isinstance(expr_id, Sequence): |
|
|
|
|
|
expr_id = [expr_id] |
|
|
|
|
|
self.expr_id = expr_id |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
for expr in self._iter: |
|
|
|
|
|
if expr._id in self.expr_id: |
|
|
|
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TracedModule(Module): |
|
|
class TracedModule(Module): |
|
@@ -579,10 +919,11 @@ class TracedModule(Module): |
|
|
argdef_graph_map = None |
|
|
argdef_graph_map = None |
|
|
argdef_outdef_map = None |
|
|
argdef_outdef_map = None |
|
|
|
|
|
|
|
|
def __init__(self, argdef_graph_map, argdef_outdef_map): |
|
|
|
|
|
|
|
|
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map): |
|
|
super(TracedModule, self).__init__() |
|
|
super(TracedModule, self).__init__() |
|
|
self.argdef_graph_map = argdef_graph_map |
|
|
self.argdef_graph_map = argdef_graph_map |
|
|
self.argdef_outdef_map = argdef_outdef_map |
|
|
self.argdef_outdef_map = argdef_outdef_map |
|
|
|
|
|
self._is_top = is_top |
|
|
|
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
def forward(self, *args, **kwargs): |
|
|
inputs, treedef = tree_flatten( |
|
|
inputs, treedef = tree_flatten( |
|
@@ -598,29 +939,58 @@ class TracedModule(Module): |
|
|
return outputs |
|
|
return outputs |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def graph(self): |
|
|
|
|
|
self._update_modulenode_ref() |
|
|
|
|
|
|
|
|
def graph(self) -> InternalGraph: |
|
|
|
|
|
if self._is_top: |
|
|
|
|
|
self._update_ref() |
|
|
assert len(self.argdef_graph_map) == 1 |
|
|
assert len(self.argdef_graph_map) == 1 |
|
|
return list(self.argdef_graph_map.values())[0] |
|
|
return list(self.argdef_graph_map.values())[0] |
|
|
|
|
|
|
|
|
def _update_modulenode_ref(self): |
|
|
|
|
|
for _, graph in self.argdef_graph_map.items(): |
|
|
|
|
|
|
|
|
def _update_ref(self, actual_node_map: Union[Dict] = None): |
|
|
|
|
|
for inp_def, graph in self.argdef_graph_map.items(): |
|
|
|
|
|
for n in graph._inputs + graph.outputs: |
|
|
|
|
|
n._top_graph = weakref.ref(graph) |
|
|
graph._inputs[0]._owner = weakref.ref(self) |
|
|
graph._inputs[0]._owner = weakref.ref(self) |
|
|
|
|
|
graph._inputs[0].actual_mnode = [] |
|
|
|
|
|
if actual_node_map is not None and inp_def in actual_node_map.keys(): |
|
|
|
|
|
graph._inputs[0].actual_mnode = actual_node_map[inp_def] |
|
|
node2obj = {} |
|
|
node2obj = {} |
|
|
|
|
|
next_actual_node_map = collections.defaultdict( |
|
|
|
|
|
lambda: collections.defaultdict(list) |
|
|
|
|
|
) |
|
|
node2obj[graph._inputs[0]] = self |
|
|
node2obj[graph._inputs[0]] = self |
|
|
for expr in graph._exprs: |
|
|
for expr in graph._exprs: |
|
|
|
|
|
for n in expr.inputs + expr.outputs: |
|
|
|
|
|
n._top_graph = weakref.ref(graph) |
|
|
|
|
|
expr._top_graph = weakref.ref(graph) |
|
|
if isinstance(expr, GetAttr) and isinstance( |
|
|
if isinstance(expr, GetAttr) and isinstance( |
|
|
expr.outputs[0], ModuleNode |
|
|
expr.outputs[0], ModuleNode |
|
|
): |
|
|
): |
|
|
obj = getattr(node2obj[expr.inputs[0]], expr.name) |
|
|
obj = getattr(node2obj[expr.inputs[0]], expr.name) |
|
|
expr.outputs[0]._owner = weakref.ref(obj) |
|
|
expr.outputs[0]._owner = weakref.ref(obj) |
|
|
node2obj[expr.outputs[0]] = obj |
|
|
node2obj[expr.outputs[0]] = obj |
|
|
if isinstance(obj, TracedModule): |
|
|
|
|
|
obj._update_modulenode_ref() |
|
|
|
|
|
|
|
|
if isinstance(expr, Constant) and isinstance( |
|
|
|
|
|
expr.outputs[0], ModuleNode |
|
|
|
|
|
): |
|
|
|
|
|
obj = expr.value |
|
|
|
|
|
expr.outputs[0]._owner = weakref.ref(obj) |
|
|
|
|
|
node2obj[expr.outputs[0]] = obj |
|
|
|
|
|
if ( |
|
|
|
|
|
isinstance(expr, CallMethod) |
|
|
|
|
|
and expr.method == "__call__" |
|
|
|
|
|
and isinstance(expr.inputs[0], ModuleNode) |
|
|
|
|
|
): |
|
|
|
|
|
obj = node2obj[expr.inputs[0]] |
|
|
|
|
|
if expr.arg_def is not None: |
|
|
|
|
|
next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0]) |
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def exprs(self): |
|
|
|
|
|
return self.graph.exprs |
|
|
|
|
|
|
|
|
for obj in node2obj.values(): |
|
|
|
|
|
if obj is self: |
|
|
|
|
|
continue |
|
|
|
|
|
mnode_map = None |
|
|
|
|
|
if obj in next_actual_node_map.keys(): |
|
|
|
|
|
mnode_map = next_actual_node_map[obj] |
|
|
|
|
|
if isinstance(obj, TracedModule): |
|
|
|
|
|
obj._update_ref(mnode_map) |
|
|
|
|
|
|
|
|
def flatten(self): |
|
|
def flatten(self): |
|
|
""" |
|
|
""" |
|
@@ -644,13 +1014,21 @@ class TracedModule(Module): |
|
|
node2obj[graph._inputs[0]] = module |
|
|
node2obj[graph._inputs[0]] = module |
|
|
if call: |
|
|
if call: |
|
|
node2obj[call.inputs[0]] = module |
|
|
node2obj[call.inputs[0]] = module |
|
|
|
|
|
repl_dict = dict(zip(graph._inputs, call.inputs)) |
|
|
|
|
|
for ind, out in enumerate(graph.outputs): |
|
|
|
|
|
if isinstance(out.expr, Input): |
|
|
|
|
|
assert out in repl_dict |
|
|
|
|
|
call_out = call.outputs[ind] |
|
|
|
|
|
for expr in call.outputs[ind].users: |
|
|
|
|
|
for index, inp in enumerate(expr.inputs): |
|
|
|
|
|
if inp is call_out: |
|
|
|
|
|
expr.inputs[index] = repl_dict[out] |
|
|
|
|
|
|
|
|
|
|
|
continue |
|
|
|
|
|
repl_dict[out] = call.outputs[ind] |
|
|
|
|
|
|
|
|
|
|
|
graph._replace_inputs_outputs(repl_dict) |
|
|
for expr in graph._exprs: |
|
|
for expr in graph._exprs: |
|
|
# replace inputs for submodule's exprx |
|
|
|
|
|
if call: |
|
|
|
|
|
repl_dict = dict( |
|
|
|
|
|
zip(graph._inputs + graph._outputs, call.inputs + call.outputs) |
|
|
|
|
|
) |
|
|
|
|
|
graph._replace_inputs_outputs(repl_dict) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(expr, GetAttr): |
|
|
if isinstance(expr, GetAttr): |
|
|
# replace GetAttr with Constant |
|
|
# replace GetAttr with Constant |
|
@@ -715,6 +1093,21 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: |
|
|
module_tracer.register_as_builtin(mod_cls) |
|
|
module_tracer.register_as_builtin(mod_cls) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap(func: Union[Callable]): |
|
|
|
|
|
assert callable(func) |
|
|
|
|
|
if hasattr(func, "__code__"): |
|
|
|
|
|
assert not isinstance(func, str) |
|
|
|
|
|
fn_name = func.__code__.co_name |
|
|
|
|
|
currentframe = inspect.currentframe() |
|
|
|
|
|
assert currentframe is not None |
|
|
|
|
|
f = currentframe.f_back |
|
|
|
|
|
assert f is not None |
|
|
|
|
|
if f.f_code.co_name != "<module>": |
|
|
|
|
|
raise NotImplementedError("wrap must be called at the top level of a module") |
|
|
|
|
|
Patcher._builtin_functions.append((f.f_globals, fn_name)) |
|
|
|
|
|
return func |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _register_all_builtin_module(): |
|
|
def _register_all_builtin_module(): |
|
|
|
|
|
|
|
|
for sub_mod in [M, M.qat, M.quantized]: |
|
|
for sub_mod in [M, M.qat, M.quantized]: |
|
@@ -749,6 +1142,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: |
|
|
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) |
|
|
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) |
|
|
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf) |
|
|
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf) |
|
|
for _, i in enumerate(inputs): |
|
|
for _, i in enumerate(inputs): |
|
|
|
|
|
assert isinstance(i, Tensor), "not support " |
|
|
if isinstance(i, RawTensor): |
|
|
if isinstance(i, RawTensor): |
|
|
NodeMixin.wrap_safe( |
|
|
NodeMixin.wrap_safe( |
|
|
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) |
|
|
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) |
|
|