@@ -7,7 +7,7 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import builtins | |||
import collections | |||
from typing import Callable, List | |||
@@ -19,6 +19,7 @@ from ...module import Module | |||
from ...tensor import Tensor | |||
from .module_tracer import active_module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
from .pytree import TreeDef | |||
class Expr: | |||
@@ -28,9 +29,22 @@ class Expr: | |||
inputs = None # type: List[Node] | |||
outputs = None # type: List[Node] | |||
def add_input(self, node): | |||
self.inputs.append(node) | |||
const_val = None # type: List[Any] | |||
arg_def = None # type: TreeDef | |||
def add_inputs(self, vals): | |||
if not isinstance(vals, collections.abc.Sequence): | |||
vals = (vals,) | |||
for val in vals: | |||
node = NodeMixin.get(val, None) | |||
if isinstance(node, (TensorNode, ModuleNode)): | |||
if node not in self.inputs: | |||
self.inputs.append(node) | |||
else: | |||
assert node is None | |||
assert type(val) in builtins.__dict__.values() | |||
idx = len(self.inputs) + len(self.const_val) | |||
self.const_val.append((idx, val)) | |||
def add_outputs(self, outputs): | |||
self.outputs = [] | |||
@@ -38,50 +52,31 @@ class Expr: | |||
outputs = (outputs,) | |||
for i in outputs: | |||
assert isinstance(i, RawTensor) | |||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
for i, node in zip(outputs, self.outputs,): | |||
NodeMixin.wrap_safe(i, node) | |||
@classmethod | |||
def get_args_node(cls, arg): | |||
""" | |||
Create nodes by ``arg``, which may be a container. | |||
Return the same structure with arg. | |||
If ``arg`` was not Tensor or Module, it will be stored as const. | |||
:param arg: tensor, module or const. | |||
""" | |||
if isinstance(arg, (RawTensor, Module)): | |||
if not NodeMixin.get(arg, None): | |||
NodeMixin.wrap_safe(arg, Constant.make(arg)) | |||
return NodeMixin.get(arg) | |||
elif isinstance(arg, collections.abc.Sequence): | |||
seq_cls = type(arg) | |||
return seq_cls([Expr.get_args_node(a) for a in arg]) | |||
def unflatten_args(self, inputs): | |||
if self.arg_def is not None: | |||
inputs = list(inputs) | |||
for idx, val in self.const_val: | |||
inputs.insert(idx, val) | |||
args, kwargs = self.arg_def.unflatten(inputs) | |||
return args, kwargs | |||
else: | |||
# TODO: assert arg type | |||
return arg # as const | |||
return inputs, {} | |||
@classmethod | |||
def get_arg_value(cls, inp_node, node2value): | |||
""" | |||
Get values from node2value by inp_node, which may be a container. | |||
Return the same structure with inp_node. | |||
If ``inp_node`` was not in node2value, it is a const. | |||
:param inp_node: nodes. | |||
:param node2value: dict from node to tensor and module. | |||
""" | |||
if inp_node in node2value: | |||
return node2value[inp_node] | |||
elif isinstance(inp_node, collections.abc.Sequence): | |||
seq_cls = type(inp_node) | |||
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) | |||
else: | |||
return inp_node | |||
@property | |||
def kwargs(self): | |||
_, kwargs = self.unflatten_args(self.inputs) | |||
return kwargs | |||
@property | |||
def args(self): | |||
args, _ = self.unflatten_args(self.inputs) | |||
return args | |||
# expr: None (i.e. fake expression which is used to mark input) | |||
@@ -144,16 +139,8 @@ class CallMethod(Expr): | |||
self.inputs = [ | |||
module, | |||
] | |||
self.const_val = [] | |||
self.method = method | |||
self.arg_names = [] | |||
self.kwargs = {} # const kwargs | |||
def add_input(self, node, arg_name=None): | |||
if arg_name == "self": # FIXME: <XP> | |||
return | |||
self.inputs.append(node) | |||
if arg_name is not None: | |||
self.arg_names.append(arg_name) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
@@ -162,19 +149,22 @@ class CallMethod(Expr): | |||
return expr | |||
def interpret(self, *inputs): | |||
mod = inputs[0] | |||
args = inputs[1:] | |||
outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||
args, kwargs = self.unflatten_args(inputs) | |||
obj = args[0] | |||
args = args[1:] | |||
outputs = getattr(obj, self.method)(*args, **kwargs) | |||
if isinstance(outputs, RawTensor): | |||
outputs = (outputs,) | |||
return outputs | |||
def __repr__(self): | |||
return "{} = CallMethod({}, {})({})".format( | |||
args = ", ".join(str(i) for i in self.args[1:]) | |||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
return "{} = {}.{}({})".format( | |||
", ".join(str(i) for i in self.outputs), | |||
self.inputs[0], | |||
self.method, | |||
", ".join(str(i) for i in self.inputs[1:]), | |||
", ".join([args, kwargs]), | |||
) | |||
@@ -227,13 +217,8 @@ class CallFunction(Expr): | |||
def __init__(self, func): | |||
assert isinstance(func, Callable) | |||
self.func = func | |||
self.const_val = [] | |||
self.inputs = [] | |||
self.arg_names = [] | |||
self.kwargs = {} # const kwargs | |||
def add_input(self, node, arg_name): | |||
self.inputs.append(node) | |||
self.arg_names.append(arg_name) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
@@ -242,18 +227,20 @@ class CallFunction(Expr): | |||
return expr | |||
def interpret(self, *inputs): | |||
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) | |||
outputs = self.func(**inp_dict, **self.kwargs) | |||
args, kwargs = self.unflatten_args(inputs) | |||
outputs = self.func(*args, **kwargs) | |||
outputs = ( | |||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
) | |||
return outputs | |||
def __repr__(self): | |||
args = ", ".join(str(i) for i in self.args) | |||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
return "{} = {}({})".format( | |||
", ".join(str(i) for i in self.outputs), | |||
self.func.__module__ + "." + self.func.__name__, | |||
", ".join(str(i) for i in self.inputs), | |||
", ".join([args, kwargs]), | |||
) | |||
@@ -15,6 +15,72 @@ from ...module import Module | |||
_active_module_tracer = None | |||
BUILTIN_ARRAY_METHOD = [ | |||
"__lt__", | |||
"__le__", | |||
"__gt__", | |||
"__ge__", | |||
"__eq__", | |||
"__ne__", | |||
"__neg__", | |||
"__pos__", | |||
"__abs__", | |||
"__invert__", | |||
"__round__", | |||
"__floor__", | |||
"__ceil__", | |||
"__add__", | |||
"__sub__", | |||
"__mul__", | |||
"__matmul__", | |||
"__truediv__", | |||
"__floordiv__", | |||
"__mod__", | |||
"__pow__", | |||
"__lshift__", | |||
"__rshift__", | |||
"__and__", | |||
"__or__", | |||
"__xor__", | |||
"__radd__", | |||
"__rsub__", | |||
"__rmul__", | |||
"__rmatmul__", | |||
"__rtruediv__", | |||
"__rfloordiv__", | |||
"__rmod__", | |||
"__rpow__", | |||
"__rlshift__", | |||
"__rrshift__", | |||
"__rand__", | |||
"__ror__", | |||
"__rxor__", | |||
"__iadd__", | |||
"__isub__", | |||
"__imul__", | |||
"__imatmul__", | |||
"__itruediv__", | |||
"__ifloordiv__", | |||
"__imod__", | |||
"__ipow__", | |||
"__ilshift__", | |||
"__irshift__", | |||
"__iand__", | |||
"__ior__", | |||
"__ixor__", | |||
"T", | |||
"astype", | |||
"reshape", | |||
"_broadcast", | |||
"transpose", | |||
"flatten", | |||
"sum", | |||
"prod", | |||
"min", | |||
"max", | |||
"mean", | |||
] | |||
def active_module_tracer(): | |||
return _active_module_tracer | |||
@@ -108,9 +174,8 @@ class Patcher: | |||
self.wrap_fn = wrap_fn | |||
for module in self._builtin_modules: | |||
self.patch_module(module) | |||
for cls in self._builtin_methods: | |||
self.patch_cls(cls) | |||
for meth in BUILTIN_ARRAY_METHOD: | |||
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) | |||
for i, j in self._builtin_functions: | |||
if id(i) not in self.visited_frames_ids: | |||
@@ -13,6 +13,7 @@ import numpy | |||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
from ...module import Module | |||
from ...tensor import Tensor | |||
from .pytree import TreeDef | |||
class Node: | |||
@@ -58,6 +59,7 @@ class ModuleNode(Node): | |||
module_type = Module # type: Type[Module] | |||
graph = None | |||
attr_type_map = None # type: Dict[str, Type[Any]] | |||
arg_def = None # type: TreeDef | |||
def __repr__(self): | |||
if self._name is None: | |||
@@ -0,0 +1,80 @@ | |||
from typing import Callable, NamedTuple | |||
SUPPORTED_TYPE = {} | |||
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||
def register_supported_type(type, flatten, unflatten): | |||
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
register_supported_type( | |||
dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x)) | |||
) | |||
register_supported_type( | |||
slice, | |||
lambda x: ([x.start, x.stop, x.step], None), | |||
lambda x, aux_data: slice(x[0], x[1], x[2]), | |||
) | |||
def tree_flatten( | |||
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True | |||
): | |||
if type(values) not in SUPPORTED_TYPE: | |||
assert is_leaf(values) | |||
return [values,], LeafDef(leaf_type(values)) | |||
rst = [] | |||
children_defs = [] | |||
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) | |||
for v in children_values: | |||
v_list, treedef = tree_flatten(v, leaf_type) | |||
rst.extend(v_list) | |||
children_defs.append(treedef) | |||
return rst, TreeDef(type(values), aux_data, children_defs) | |||
class TreeDef: | |||
def __init__(self, type, aux_data, children_defs): | |||
self.type = type | |||
self.aux_data = aux_data | |||
self.children_defs = children_defs | |||
self.num_leaves = sum(ch.num_leaves for ch in children_defs) | |||
def unflatten(self, leaves): | |||
assert len(leaves) == self.num_leaves | |||
start = 0 | |||
children = [] | |||
for ch in self.children_defs: | |||
children.append(ch.unflatten(leaves[start : start + ch.num_leaves])) | |||
start += ch.num_leaves | |||
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) | |||
def __eq__(self, other): | |||
return ( | |||
self.type == other.type | |||
and self.aux_data == other.aux_data | |||
and self.num_leaves == other.num_leaves | |||
and self.children_defs == other.children_defs | |||
) | |||
def __repr__(self): | |||
return "{}[{}]".format(self.type.__name__, self.children_defs) | |||
class LeafDef(TreeDef): | |||
def __init__(self, type): | |||
super().__init__(type, None, []) | |||
self.num_leaves = 1 | |||
def unflatten(self, leaves): | |||
assert len(leaves) == 1 | |||
assert isinstance(leaves[0], self.type), self.type | |||
return leaves[0] | |||
def __repr__(self): | |||
return "Leaf({})".format(self.type.__name__) |
@@ -9,9 +9,11 @@ | |||
import collections | |||
import copy | |||
import functools | |||
from inspect import getmembers, isclass, ismethod | |||
from typing import List, Type | |||
from ... import module as M | |||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
from ...core._imperative_rt.core2 import ( | |||
is_tracing_module, | |||
set_module_tracing, | |||
@@ -28,6 +30,16 @@ from .module_tracer import ( | |||
set_active_module_tracer, | |||
) | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
from .pytree import tree_flatten | |||
def _leaf_type(node): | |||
if isinstance(node, RawTensor): | |||
return (Tensor, TensorNode) | |||
elif isinstance(node, (NodeMixin, Module)): | |||
return (Module, ModuleNode, NodeMixin) | |||
else: | |||
return type(node) | |||
class InternalGraph: | |||
@@ -65,9 +77,7 @@ class InternalGraph: | |||
for n, v in zip(self._inputs, inputs): | |||
node2value[n] = v | |||
for expr in self._exprs: | |||
values = expr.interpret( | |||
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||
) | |||
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
for n, v in zip(expr.outputs, values): | |||
node2value[n] = v | |||
return list(node2value[i] for i in self._outputs) | |||
@@ -80,37 +90,39 @@ class InternalGraph: | |||
) | |||
def _get_meth_name(obj, func): | |||
for cls in type(obj).mro(): | |||
for k, v in cls.__dict__.items(): | |||
if v == func: | |||
return k | |||
return None | |||
def _wrapped_function(orig_func): | |||
@functools.wraps(orig_func) | |||
def wrapped_fn(*inputs, **kwargs): | |||
def wrapped_fn(*args, **kwargs): | |||
if is_tracing_module(): | |||
unset_module_tracing() | |||
const_kwargs = {} | |||
arg_names = orig_func.__code__.co_varnames | |||
if orig_func.__qualname__.split(".").__len__() > 1: | |||
# FIXME: a robust way to distinguish method and function. <XP> | |||
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) | |||
for i in inputs: | |||
if not NodeMixin.get(i, None): | |||
if isinstance(i, (RawTensor, NodeMixin)): | |||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||
meth_name = _get_meth_name(args[0], wrapped_fn) | |||
if meth_name: | |||
self = inputs[0] | |||
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) | |||
call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
else: | |||
call_node = CallFunction.make(orig_func) | |||
def add_input(inp, varname=None): | |||
node = Expr.get_args_node(inp) | |||
if node is not None: | |||
call_node.add_input(node, varname) | |||
else: | |||
const_kwargs[varname] = inp | |||
for ind, inp in enumerate(inputs): | |||
add_input(inp, arg_names[ind]) | |||
for k, v in kwargs.items(): | |||
add_input(v, k) | |||
call_node.kwargs = const_kwargs | |||
outputs = orig_func(*inputs, **kwargs) | |||
call_node.add_inputs(inputs) | |||
call_node.arg_def = tree_def | |||
outputs = orig_func(*args, **kwargs) | |||
call_node.add_outputs(outputs) | |||
set_module_tracing() | |||
return outputs | |||
return orig_func(*inputs, **kwargs) | |||
return orig_func(*args, **kwargs) | |||
return wrapped_fn | |||
@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin): | |||
_mod = None # type: Module | |||
_body = None # type: InternalGraph | |||
_is_builtin = None # type: bool | |||
_arg_def = None # type: TreeDef | |||
__builder_attributes__ = [ | |||
"_mod", | |||
"_body", | |||
"_NodeMixin__node", | |||
"_is_builtin", | |||
"_is_traced", | |||
"build", | |||
"_arg_def" "build", | |||
] | |||
def __init__(self, mod): | |||
@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin): | |||
node = NodeMixin.get(self) | |||
node.graph = self._body | |||
node.attr_type_map = {} | |||
node.arg_def = self._arg_def | |||
traced_module = TracedModule(node) | |||
for k, v in self.__dict__.items(): | |||
if k not in TracedModuleBuilder.__builder_attributes__: | |||
@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin): | |||
traced_module.m_node.attr_type_map[k] = type(v) | |||
return traced_module | |||
def __call__(self, *inputs, **kwargs): | |||
def __call__(self, *args, **kwargs): | |||
assert isinstance(self._mod, Module) | |||
for arg in args: | |||
assert isinstance(arg, RawTensor) | |||
for k, v in kwargs.items(): | |||
assert isinstance(v, RawTensor) | |||
# prepare args and kwargs for inner graph | |||
def mark_constant(x): | |||
node = NodeMixin.get(x, None) | |||
if node is None: # capture as constant | |||
NodeMixin.wrap(x, lambda: Constant.make(x)) | |||
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||
if self._arg_def is None: | |||
self._arg_def = tree_def | |||
assert self._arg_def == tree_def | |||
for i in inputs: | |||
mark_constant(i) | |||
for k, v in kwargs.items(): | |||
mark_constant(v) | |||
callnode = CallMethod.make(NodeMixin.get(self)) | |||
def add_input(x): | |||
callnode.add_input(NodeMixin.get(x)) | |||
callnode.add_inputs(inputs) | |||
for i in inputs: | |||
add_input(i) | |||
for k, v in kwargs.items(): | |||
add_input(v) | |||
callnode.arg_def = tree_def | |||
if self._is_builtin or self._is_traced: | |||
unset_module_tracing() | |||
outputs = self._mod(*inputs, **kwargs) | |||
outputs = self._mod(*args, **kwargs) | |||
set_module_tracing() | |||
if self._is_builtin: | |||
self._body = None | |||
@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin): | |||
) | |||
# prepare args and kwargs for inner graph | |||
def wrap(x): | |||
# wrapped = copy.copy(x) # FIXME | |||
wrapped = x # FIXME: <XP> | |||
wrapped = copy.copy(x) # FIXME | |||
NodeMixin.wrap( | |||
wrapped, | |||
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
) | |||
return wrapped | |||
args = [] | |||
for i in inputs: | |||
args = [self] | |||
for i in inputs[1:]: | |||
args.append(wrap(i)) | |||
for k, v in kwargs.items(): | |||
kwargs[k] = wrap(v) | |||
args, kwargs = tree_def.unflatten(args) | |||
active_module_tracer().patcher.auto_patch( | |||
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||
) | |||
outputs = type(self._mod).forward(self, *args, **kwargs) | |||
outputs = type(self._mod).forward(*args, **kwargs) | |||
for i in ( | |||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
@@ -269,8 +282,10 @@ class TracedModule(Module): | |||
super(TracedModule, self).__init__() | |||
self.m_node = node | |||
def forward(self, *inputs): | |||
rst = self.m_node.graph.interpret(self, *inputs) | |||
def forward(self, *args, **kwargs): | |||
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||
assert treedef == self.m_node.arg_def | |||
rst = self.m_node.graph.interpret(*inputs) | |||
if len(rst) == 1: | |||
rst = rst[0] | |||
return rst | |||
@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: | |||
def _register_all_builtin_module(): | |||
from inspect import getmembers, isclass | |||
for sub_mod in [M, M.qat, M.quantized]: | |||
for m in getmembers(sub_mod): | |||
@@ -357,7 +371,7 @@ def _register_all_builtin_module(): | |||
module_tracer.register_as_builtin(m[1]) | |||
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: | |||
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: | |||
""" | |||
Traces module ``mod`` and returns corresponding TracedModule. | |||
@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||
builder = TracedModuleBuilder(mod) | |||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
inputs, _ = tree_flatten((args, kwargs)) | |||
for _, i in enumerate(inputs): | |||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
for k, v in kwargs.items(): | |||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
builder(*inputs, **kwargs) | |||
NodeMixin.wrap_safe( | |||
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) | |||
) | |||
builder(*args, **kwargs) | |||
active_module_tracer().pop_scope() | |||
return builder.build() | |||
finally: | |||
set_active_module_tracer(None) | |||