@@ -7,7 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import builtins | |||||
import collections | import collections | ||||
from typing import Callable, List | from typing import Callable, List | ||||
@@ -19,6 +19,7 @@ from ...module import Module | |||||
from ...tensor import Tensor | from ...tensor import Tensor | ||||
from .module_tracer import active_module_tracer | from .module_tracer import active_module_tracer | ||||
from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
from .pytree import TreeDef | |||||
class Expr: | class Expr: | ||||
@@ -28,9 +29,22 @@ class Expr: | |||||
inputs = None # type: List[Node] | inputs = None # type: List[Node] | ||||
outputs = None # type: List[Node] | outputs = None # type: List[Node] | ||||
def add_input(self, node): | |||||
self.inputs.append(node) | |||||
const_val = None # type: List[Any] | |||||
arg_def = None # type: TreeDef | |||||
def add_inputs(self, vals): | |||||
if not isinstance(vals, collections.abc.Sequence): | |||||
vals = (vals,) | |||||
for val in vals: | |||||
node = NodeMixin.get(val, None) | |||||
if isinstance(node, (TensorNode, ModuleNode)): | |||||
if node not in self.inputs: | |||||
self.inputs.append(node) | |||||
else: | |||||
assert node is None | |||||
assert type(val) in builtins.__dict__.values() | |||||
idx = len(self.inputs) + len(self.const_val) | |||||
self.const_val.append((idx, val)) | |||||
def add_outputs(self, outputs): | def add_outputs(self, outputs): | ||||
self.outputs = [] | self.outputs = [] | ||||
@@ -38,50 +52,31 @@ class Expr: | |||||
outputs = (outputs,) | outputs = (outputs,) | ||||
for i in outputs: | for i in outputs: | ||||
assert isinstance(i, RawTensor) | |||||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | ||||
for i, node in zip(outputs, self.outputs,): | for i, node in zip(outputs, self.outputs,): | ||||
NodeMixin.wrap_safe(i, node) | NodeMixin.wrap_safe(i, node) | ||||
@classmethod | |||||
def get_args_node(cls, arg): | |||||
""" | |||||
Create nodes by ``arg``, which may be a container. | |||||
Return the same structure with arg. | |||||
If ``arg`` was not Tensor or Module, it will be stored as const. | |||||
:param arg: tensor, module or const. | |||||
""" | |||||
if isinstance(arg, (RawTensor, Module)): | |||||
if not NodeMixin.get(arg, None): | |||||
NodeMixin.wrap_safe(arg, Constant.make(arg)) | |||||
return NodeMixin.get(arg) | |||||
elif isinstance(arg, collections.abc.Sequence): | |||||
seq_cls = type(arg) | |||||
return seq_cls([Expr.get_args_node(a) for a in arg]) | |||||
def unflatten_args(self, inputs): | |||||
if self.arg_def is not None: | |||||
inputs = list(inputs) | |||||
for idx, val in self.const_val: | |||||
inputs.insert(idx, val) | |||||
args, kwargs = self.arg_def.unflatten(inputs) | |||||
return args, kwargs | |||||
else: | else: | ||||
# TODO: assert arg type | |||||
return arg # as const | |||||
return inputs, {} | |||||
@classmethod | |||||
def get_arg_value(cls, inp_node, node2value): | |||||
""" | |||||
Get values from node2value by inp_node, which may be a container. | |||||
Return the same structure with inp_node. | |||||
If ``inp_node`` was not in node2value, it is a const. | |||||
:param inp_node: nodes. | |||||
:param node2value: dict from node to tensor and module. | |||||
""" | |||||
if inp_node in node2value: | |||||
return node2value[inp_node] | |||||
elif isinstance(inp_node, collections.abc.Sequence): | |||||
seq_cls = type(inp_node) | |||||
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) | |||||
else: | |||||
return inp_node | |||||
@property | |||||
def kwargs(self): | |||||
_, kwargs = self.unflatten_args(self.inputs) | |||||
return kwargs | |||||
@property | |||||
def args(self): | |||||
args, _ = self.unflatten_args(self.inputs) | |||||
return args | |||||
# expr: None (i.e. fake expression which is used to mark input) | # expr: None (i.e. fake expression which is used to mark input) | ||||
@@ -144,16 +139,8 @@ class CallMethod(Expr): | |||||
self.inputs = [ | self.inputs = [ | ||||
module, | module, | ||||
] | ] | ||||
self.const_val = [] | |||||
self.method = method | self.method = method | ||||
self.arg_names = [] | |||||
self.kwargs = {} # const kwargs | |||||
def add_input(self, node, arg_name=None): | |||||
if arg_name == "self": # FIXME: <XP> | |||||
return | |||||
self.inputs.append(node) | |||||
if arg_name is not None: | |||||
self.arg_names.append(arg_name) | |||||
@classmethod | @classmethod | ||||
def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
@@ -162,19 +149,22 @@ class CallMethod(Expr): | |||||
return expr | return expr | ||||
def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
mod = inputs[0] | |||||
args = inputs[1:] | |||||
outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||||
args, kwargs = self.unflatten_args(inputs) | |||||
obj = args[0] | |||||
args = args[1:] | |||||
outputs = getattr(obj, self.method)(*args, **kwargs) | |||||
if isinstance(outputs, RawTensor): | if isinstance(outputs, RawTensor): | ||||
outputs = (outputs,) | outputs = (outputs,) | ||||
return outputs | return outputs | ||||
def __repr__(self): | def __repr__(self): | ||||
return "{} = CallMethod({}, {})({})".format( | |||||
args = ", ".join(str(i) for i in self.args[1:]) | |||||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||||
return "{} = {}.{}({})".format( | |||||
", ".join(str(i) for i in self.outputs), | ", ".join(str(i) for i in self.outputs), | ||||
self.inputs[0], | self.inputs[0], | ||||
self.method, | self.method, | ||||
", ".join(str(i) for i in self.inputs[1:]), | |||||
", ".join([args, kwargs]), | |||||
) | ) | ||||
@@ -227,13 +217,8 @@ class CallFunction(Expr): | |||||
def __init__(self, func): | def __init__(self, func): | ||||
assert isinstance(func, Callable) | assert isinstance(func, Callable) | ||||
self.func = func | self.func = func | ||||
self.const_val = [] | |||||
self.inputs = [] | self.inputs = [] | ||||
self.arg_names = [] | |||||
self.kwargs = {} # const kwargs | |||||
def add_input(self, node, arg_name): | |||||
self.inputs.append(node) | |||||
self.arg_names.append(arg_name) | |||||
@classmethod | @classmethod | ||||
def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
@@ -242,18 +227,20 @@ class CallFunction(Expr): | |||||
return expr | return expr | ||||
def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) | |||||
outputs = self.func(**inp_dict, **self.kwargs) | |||||
args, kwargs = self.unflatten_args(inputs) | |||||
outputs = self.func(*args, **kwargs) | |||||
outputs = ( | outputs = ( | ||||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | ||||
) | ) | ||||
return outputs | return outputs | ||||
def __repr__(self): | def __repr__(self): | ||||
args = ", ".join(str(i) for i in self.args) | |||||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||||
return "{} = {}({})".format( | return "{} = {}({})".format( | ||||
", ".join(str(i) for i in self.outputs), | ", ".join(str(i) for i in self.outputs), | ||||
self.func.__module__ + "." + self.func.__name__, | self.func.__module__ + "." + self.func.__name__, | ||||
", ".join(str(i) for i in self.inputs), | |||||
", ".join([args, kwargs]), | |||||
) | ) | ||||
@@ -15,6 +15,72 @@ from ...module import Module | |||||
_active_module_tracer = None | _active_module_tracer = None | ||||
BUILTIN_ARRAY_METHOD = [ | |||||
"__lt__", | |||||
"__le__", | |||||
"__gt__", | |||||
"__ge__", | |||||
"__eq__", | |||||
"__ne__", | |||||
"__neg__", | |||||
"__pos__", | |||||
"__abs__", | |||||
"__invert__", | |||||
"__round__", | |||||
"__floor__", | |||||
"__ceil__", | |||||
"__add__", | |||||
"__sub__", | |||||
"__mul__", | |||||
"__matmul__", | |||||
"__truediv__", | |||||
"__floordiv__", | |||||
"__mod__", | |||||
"__pow__", | |||||
"__lshift__", | |||||
"__rshift__", | |||||
"__and__", | |||||
"__or__", | |||||
"__xor__", | |||||
"__radd__", | |||||
"__rsub__", | |||||
"__rmul__", | |||||
"__rmatmul__", | |||||
"__rtruediv__", | |||||
"__rfloordiv__", | |||||
"__rmod__", | |||||
"__rpow__", | |||||
"__rlshift__", | |||||
"__rrshift__", | |||||
"__rand__", | |||||
"__ror__", | |||||
"__rxor__", | |||||
"__iadd__", | |||||
"__isub__", | |||||
"__imul__", | |||||
"__imatmul__", | |||||
"__itruediv__", | |||||
"__ifloordiv__", | |||||
"__imod__", | |||||
"__ipow__", | |||||
"__ilshift__", | |||||
"__irshift__", | |||||
"__iand__", | |||||
"__ior__", | |||||
"__ixor__", | |||||
"T", | |||||
"astype", | |||||
"reshape", | |||||
"_broadcast", | |||||
"transpose", | |||||
"flatten", | |||||
"sum", | |||||
"prod", | |||||
"min", | |||||
"max", | |||||
"mean", | |||||
] | |||||
def active_module_tracer(): | def active_module_tracer(): | ||||
return _active_module_tracer | return _active_module_tracer | ||||
@@ -108,9 +174,8 @@ class Patcher: | |||||
self.wrap_fn = wrap_fn | self.wrap_fn = wrap_fn | ||||
for module in self._builtin_modules: | for module in self._builtin_modules: | ||||
self.patch_module(module) | self.patch_module(module) | ||||
for cls in self._builtin_methods: | |||||
self.patch_cls(cls) | |||||
for meth in BUILTIN_ARRAY_METHOD: | |||||
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) | |||||
for i, j in self._builtin_functions: | for i, j in self._builtin_functions: | ||||
if id(i) not in self.visited_frames_ids: | if id(i) not in self.visited_frames_ids: | ||||
@@ -13,6 +13,7 @@ import numpy | |||||
from ...core._imperative_rt.core2 import Tensor as RawTensor | from ...core._imperative_rt.core2 import Tensor as RawTensor | ||||
from ...module import Module | from ...module import Module | ||||
from ...tensor import Tensor | from ...tensor import Tensor | ||||
from .pytree import TreeDef | |||||
class Node: | class Node: | ||||
@@ -58,6 +59,7 @@ class ModuleNode(Node): | |||||
module_type = Module # type: Type[Module] | module_type = Module # type: Type[Module] | ||||
graph = None | graph = None | ||||
attr_type_map = None # type: Dict[str, Type[Any]] | attr_type_map = None # type: Dict[str, Type[Any]] | ||||
arg_def = None # type: TreeDef | |||||
def __repr__(self): | def __repr__(self): | ||||
if self._name is None: | if self._name is None: | ||||
@@ -0,0 +1,80 @@ | |||||
from typing import Callable, NamedTuple | |||||
SUPPORTED_TYPE = {} | |||||
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||||
def register_supported_type(type, flatten, unflatten): | |||||
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||||
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||||
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x)) | |||||
register_supported_type( | |||||
dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x)) | |||||
) | |||||
register_supported_type( | |||||
slice, | |||||
lambda x: ([x.start, x.stop, x.step], None), | |||||
lambda x, aux_data: slice(x[0], x[1], x[2]), | |||||
) | |||||
def tree_flatten( | |||||
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True | |||||
): | |||||
if type(values) not in SUPPORTED_TYPE: | |||||
assert is_leaf(values) | |||||
return [values,], LeafDef(leaf_type(values)) | |||||
rst = [] | |||||
children_defs = [] | |||||
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) | |||||
for v in children_values: | |||||
v_list, treedef = tree_flatten(v, leaf_type) | |||||
rst.extend(v_list) | |||||
children_defs.append(treedef) | |||||
return rst, TreeDef(type(values), aux_data, children_defs) | |||||
class TreeDef: | |||||
def __init__(self, type, aux_data, children_defs): | |||||
self.type = type | |||||
self.aux_data = aux_data | |||||
self.children_defs = children_defs | |||||
self.num_leaves = sum(ch.num_leaves for ch in children_defs) | |||||
def unflatten(self, leaves): | |||||
assert len(leaves) == self.num_leaves | |||||
start = 0 | |||||
children = [] | |||||
for ch in self.children_defs: | |||||
children.append(ch.unflatten(leaves[start : start + ch.num_leaves])) | |||||
start += ch.num_leaves | |||||
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) | |||||
def __eq__(self, other): | |||||
return ( | |||||
self.type == other.type | |||||
and self.aux_data == other.aux_data | |||||
and self.num_leaves == other.num_leaves | |||||
and self.children_defs == other.children_defs | |||||
) | |||||
def __repr__(self): | |||||
return "{}[{}]".format(self.type.__name__, self.children_defs) | |||||
class LeafDef(TreeDef): | |||||
def __init__(self, type): | |||||
super().__init__(type, None, []) | |||||
self.num_leaves = 1 | |||||
def unflatten(self, leaves): | |||||
assert len(leaves) == 1 | |||||
assert isinstance(leaves[0], self.type), self.type | |||||
return leaves[0] | |||||
def __repr__(self): | |||||
return "Leaf({})".format(self.type.__name__) |
@@ -9,9 +9,11 @@ | |||||
import collections | import collections | ||||
import copy | import copy | ||||
import functools | import functools | ||||
from inspect import getmembers, isclass, ismethod | |||||
from typing import List, Type | from typing import List, Type | ||||
from ... import module as M | from ... import module as M | ||||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||||
from ...core._imperative_rt.core2 import ( | from ...core._imperative_rt.core2 import ( | ||||
is_tracing_module, | is_tracing_module, | ||||
set_module_tracing, | set_module_tracing, | ||||
@@ -28,6 +30,16 @@ from .module_tracer import ( | |||||
set_active_module_tracer, | set_active_module_tracer, | ||||
) | ) | ||||
from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
from .pytree import tree_flatten | |||||
def _leaf_type(node): | |||||
if isinstance(node, RawTensor): | |||||
return (Tensor, TensorNode) | |||||
elif isinstance(node, (NodeMixin, Module)): | |||||
return (Module, ModuleNode, NodeMixin) | |||||
else: | |||||
return type(node) | |||||
class InternalGraph: | class InternalGraph: | ||||
@@ -65,9 +77,7 @@ class InternalGraph: | |||||
for n, v in zip(self._inputs, inputs): | for n, v in zip(self._inputs, inputs): | ||||
node2value[n] = v | node2value[n] = v | ||||
for expr in self._exprs: | for expr in self._exprs: | ||||
values = expr.interpret( | |||||
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||||
) | |||||
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||||
for n, v in zip(expr.outputs, values): | for n, v in zip(expr.outputs, values): | ||||
node2value[n] = v | node2value[n] = v | ||||
return list(node2value[i] for i in self._outputs) | return list(node2value[i] for i in self._outputs) | ||||
@@ -80,37 +90,39 @@ class InternalGraph: | |||||
) | ) | ||||
def _get_meth_name(obj, func): | |||||
for cls in type(obj).mro(): | |||||
for k, v in cls.__dict__.items(): | |||||
if v == func: | |||||
return k | |||||
return None | |||||
def _wrapped_function(orig_func): | def _wrapped_function(orig_func): | ||||
@functools.wraps(orig_func) | @functools.wraps(orig_func) | ||||
def wrapped_fn(*inputs, **kwargs): | |||||
def wrapped_fn(*args, **kwargs): | |||||
if is_tracing_module(): | if is_tracing_module(): | ||||
unset_module_tracing() | unset_module_tracing() | ||||
const_kwargs = {} | |||||
arg_names = orig_func.__code__.co_varnames | |||||
if orig_func.__qualname__.split(".").__len__() > 1: | |||||
# FIXME: a robust way to distinguish method and function. <XP> | |||||
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) | |||||
for i in inputs: | |||||
if not NodeMixin.get(i, None): | |||||
if isinstance(i, (RawTensor, NodeMixin)): | |||||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||||
meth_name = _get_meth_name(args[0], wrapped_fn) | |||||
if meth_name: | |||||
self = inputs[0] | self = inputs[0] | ||||
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) | |||||
call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||||
else: | else: | ||||
call_node = CallFunction.make(orig_func) | call_node = CallFunction.make(orig_func) | ||||
def add_input(inp, varname=None): | |||||
node = Expr.get_args_node(inp) | |||||
if node is not None: | |||||
call_node.add_input(node, varname) | |||||
else: | |||||
const_kwargs[varname] = inp | |||||
for ind, inp in enumerate(inputs): | |||||
add_input(inp, arg_names[ind]) | |||||
for k, v in kwargs.items(): | |||||
add_input(v, k) | |||||
call_node.kwargs = const_kwargs | |||||
outputs = orig_func(*inputs, **kwargs) | |||||
call_node.add_inputs(inputs) | |||||
call_node.arg_def = tree_def | |||||
outputs = orig_func(*args, **kwargs) | |||||
call_node.add_outputs(outputs) | call_node.add_outputs(outputs) | ||||
set_module_tracing() | set_module_tracing() | ||||
return outputs | return outputs | ||||
return orig_func(*inputs, **kwargs) | |||||
return orig_func(*args, **kwargs) | |||||
return wrapped_fn | return wrapped_fn | ||||
@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin): | |||||
_mod = None # type: Module | _mod = None # type: Module | ||||
_body = None # type: InternalGraph | _body = None # type: InternalGraph | ||||
_is_builtin = None # type: bool | _is_builtin = None # type: bool | ||||
_arg_def = None # type: TreeDef | |||||
__builder_attributes__ = [ | __builder_attributes__ = [ | ||||
"_mod", | "_mod", | ||||
"_body", | "_body", | ||||
"_NodeMixin__node", | "_NodeMixin__node", | ||||
"_is_builtin", | "_is_builtin", | ||||
"_is_traced", | "_is_traced", | ||||
"build", | |||||
"_arg_def" "build", | |||||
] | ] | ||||
def __init__(self, mod): | def __init__(self, mod): | ||||
@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin): | |||||
node = NodeMixin.get(self) | node = NodeMixin.get(self) | ||||
node.graph = self._body | node.graph = self._body | ||||
node.attr_type_map = {} | node.attr_type_map = {} | ||||
node.arg_def = self._arg_def | |||||
traced_module = TracedModule(node) | traced_module = TracedModule(node) | ||||
for k, v in self.__dict__.items(): | for k, v in self.__dict__.items(): | ||||
if k not in TracedModuleBuilder.__builder_attributes__: | if k not in TracedModuleBuilder.__builder_attributes__: | ||||
@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin): | |||||
traced_module.m_node.attr_type_map[k] = type(v) | traced_module.m_node.attr_type_map[k] = type(v) | ||||
return traced_module | return traced_module | ||||
def __call__(self, *inputs, **kwargs): | |||||
def __call__(self, *args, **kwargs): | |||||
assert isinstance(self._mod, Module) | assert isinstance(self._mod, Module) | ||||
for arg in args: | |||||
assert isinstance(arg, RawTensor) | |||||
for k, v in kwargs.items(): | |||||
assert isinstance(v, RawTensor) | |||||
# prepare args and kwargs for inner graph | # prepare args and kwargs for inner graph | ||||
def mark_constant(x): | def mark_constant(x): | ||||
node = NodeMixin.get(x, None) | node = NodeMixin.get(x, None) | ||||
if node is None: # capture as constant | if node is None: # capture as constant | ||||
NodeMixin.wrap(x, lambda: Constant.make(x)) | NodeMixin.wrap(x, lambda: Constant.make(x)) | ||||
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||||
if self._arg_def is None: | |||||
self._arg_def = tree_def | |||||
assert self._arg_def == tree_def | |||||
for i in inputs: | for i in inputs: | ||||
mark_constant(i) | mark_constant(i) | ||||
for k, v in kwargs.items(): | |||||
mark_constant(v) | |||||
callnode = CallMethod.make(NodeMixin.get(self)) | callnode = CallMethod.make(NodeMixin.get(self)) | ||||
def add_input(x): | |||||
callnode.add_input(NodeMixin.get(x)) | |||||
callnode.add_inputs(inputs) | |||||
for i in inputs: | |||||
add_input(i) | |||||
for k, v in kwargs.items(): | |||||
add_input(v) | |||||
callnode.arg_def = tree_def | |||||
if self._is_builtin or self._is_traced: | if self._is_builtin or self._is_traced: | ||||
unset_module_tracing() | unset_module_tracing() | ||||
outputs = self._mod(*inputs, **kwargs) | |||||
outputs = self._mod(*args, **kwargs) | |||||
set_module_tracing() | set_module_tracing() | ||||
if self._is_builtin: | if self._is_builtin: | ||||
self._body = None | self._body = None | ||||
@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin): | |||||
) | ) | ||||
# prepare args and kwargs for inner graph | # prepare args and kwargs for inner graph | ||||
def wrap(x): | def wrap(x): | ||||
# wrapped = copy.copy(x) # FIXME | |||||
wrapped = x # FIXME: <XP> | |||||
wrapped = copy.copy(x) # FIXME | |||||
NodeMixin.wrap( | NodeMixin.wrap( | ||||
wrapped, | wrapped, | ||||
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | ||||
) | ) | ||||
return wrapped | return wrapped | ||||
args = [] | |||||
for i in inputs: | |||||
args = [self] | |||||
for i in inputs[1:]: | |||||
args.append(wrap(i)) | args.append(wrap(i)) | ||||
for k, v in kwargs.items(): | |||||
kwargs[k] = wrap(v) | |||||
args, kwargs = tree_def.unflatten(args) | |||||
active_module_tracer().patcher.auto_patch( | active_module_tracer().patcher.auto_patch( | ||||
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | ||||
) | ) | ||||
outputs = type(self._mod).forward(self, *args, **kwargs) | |||||
outputs = type(self._mod).forward(*args, **kwargs) | |||||
for i in ( | for i in ( | ||||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | ||||
@@ -269,8 +282,10 @@ class TracedModule(Module): | |||||
super(TracedModule, self).__init__() | super(TracedModule, self).__init__() | ||||
self.m_node = node | self.m_node = node | ||||
def forward(self, *inputs): | |||||
rst = self.m_node.graph.interpret(self, *inputs) | |||||
def forward(self, *args, **kwargs): | |||||
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||||
assert treedef == self.m_node.arg_def | |||||
rst = self.m_node.graph.interpret(*inputs) | |||||
if len(rst) == 1: | if len(rst) == 1: | ||||
rst = rst[0] | rst = rst[0] | ||||
return rst | return rst | ||||
@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: | |||||
def _register_all_builtin_module(): | def _register_all_builtin_module(): | ||||
from inspect import getmembers, isclass | |||||
for sub_mod in [M, M.qat, M.quantized]: | for sub_mod in [M, M.qat, M.quantized]: | ||||
for m in getmembers(sub_mod): | for m in getmembers(sub_mod): | ||||
@@ -357,7 +371,7 @@ def _register_all_builtin_module(): | |||||
module_tracer.register_as_builtin(m[1]) | module_tracer.register_as_builtin(m[1]) | ||||
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: | |||||
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: | |||||
""" | """ | ||||
Traces module ``mod`` and returns corresponding TracedModule. | Traces module ``mod`` and returns corresponding TracedModule. | ||||
@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||||
builder = TracedModuleBuilder(mod) | builder = TracedModuleBuilder(mod) | ||||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | ||||
inputs, _ = tree_flatten((args, kwargs)) | |||||
for _, i in enumerate(inputs): | for _, i in enumerate(inputs): | ||||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
for k, v in kwargs.items(): | |||||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
builder(*inputs, **kwargs) | |||||
NodeMixin.wrap_safe( | |||||
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) | |||||
) | |||||
builder(*args, **kwargs) | |||||
active_module_tracer().pop_scope() | active_module_tracer().pop_scope() | ||||
return builder.build() | return builder.build() | ||||
finally: | finally: | ||||
set_active_module_tracer(None) | set_active_module_tracer(None) | ||||