Browse Source

feat(traced_module): add pytree

GitOrigin-RevId: 6c6e53521c
release-1.6
Megvii Engine Team 4 years ago
parent
commit
f2691566fd
5 changed files with 266 additions and 120 deletions
  1. +51
    -64
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +68
    -3
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +2
    -0
      imperative/python/megengine/experimental/traced_module/node.py
  4. +80
    -0
      imperative/python/megengine/experimental/traced_module/pytree.py
  5. +65
    -53
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 51
- 64
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import builtins
import collections import collections
from typing import Callable, List from typing import Callable, List


@@ -19,6 +19,7 @@ from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .module_tracer import active_module_tracer from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef




class Expr: class Expr:
@@ -28,9 +29,22 @@ class Expr:


inputs = None # type: List[Node] inputs = None # type: List[Node]
outputs = None # type: List[Node] outputs = None # type: List[Node]

def add_input(self, node):
self.inputs.append(node)
const_val = None # type: List[Any]
arg_def = None # type: TreeDef

def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence):
vals = (vals,)
for val in vals:
node = NodeMixin.get(val, None)
if isinstance(node, (TensorNode, ModuleNode)):
if node not in self.inputs:
self.inputs.append(node)
else:
assert node is None
assert type(val) in builtins.__dict__.values()
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))


def add_outputs(self, outputs): def add_outputs(self, outputs):
self.outputs = [] self.outputs = []
@@ -38,50 +52,31 @@ class Expr:
outputs = (outputs,) outputs = (outputs,)


for i in outputs: for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) self.outputs.append(NodeMixin.get_wrapped_type(i)(self))


for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node) NodeMixin.wrap_safe(i, node)


@classmethod
def get_args_node(cls, arg):
"""
Create nodes by ``arg``, which may be a container.
Return the same structure with arg.

If ``arg`` was not Tensor or Module, it will be stored as const.

:param arg: tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
def unflatten_args(self, inputs):
if self.arg_def is not None:
inputs = list(inputs)
for idx, val in self.const_val:
inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
return args, kwargs
else: else:
# TODO: assert arg type
return arg # as const
return inputs, {}


@classmethod
def get_arg_value(cls, inp_node, node2value):
"""
Get values from node2value by inp_node, which may be a container.
Return the same structure with inp_node.

If ``inp_node`` was not in node2value, it is a const.

:param inp_node: nodes.
:param node2value: dict from node to tensor and module.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
@property
def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs)
return kwargs

@property
def args(self):
args, _ = self.unflatten_args(self.inputs)
return args




# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
@@ -144,16 +139,8 @@ class CallMethod(Expr):
self.inputs = [ self.inputs = [
module, module,
] ]
self.const_val = []
self.method = method self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)
if arg_name is not None:
self.arg_names.append(arg_name)


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
@@ -162,19 +149,22 @@ class CallMethod(Expr):
return expr return expr


def interpret(self, *inputs): def interpret(self, *inputs):
mod = inputs[0]
args = inputs[1:]
outputs = getattr(mod, self.method)(*args, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = (outputs,) outputs = (outputs,)
return outputs return outputs


def __repr__(self): def __repr__(self):
return "{} = CallMethod({}, {})({})".format(
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.inputs[0], self.inputs[0],
self.method, self.method,
", ".join(str(i) for i in self.inputs[1:]),
", ".join([args, kwargs]),
) )




@@ -227,13 +217,8 @@ class CallFunction(Expr):
def __init__(self, func): def __init__(self, func):
assert isinstance(func, Callable) assert isinstance(func, Callable)
self.func = func self.func = func
self.const_val = []
self.inputs = [] self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
@@ -242,18 +227,20 @@ class CallFunction(Expr):
return expr return expr


def interpret(self, *inputs): def interpret(self, *inputs):
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)])
outputs = self.func(**inp_dict, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
outputs = ( outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
) )
return outputs return outputs


def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format( return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__, self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs),
", ".join([args, kwargs]),
) )






+ 68
- 3
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -15,6 +15,72 @@ from ...module import Module


_active_module_tracer = None _active_module_tracer = None


# Names of the ArrayMethodMixin attributes that ``Patcher`` patches with its
# ``wrap_fn``. The list is assembled group by group, in this order:
# comparisons, unary operators, binary operators in their plain / reflected
# ("r" prefix) / in-place ("i" prefix) forms, and finally named attributes.
BUILTIN_ARRAY_METHOD = (
    ["__{}__".format(op) for op in ("lt", "le", "gt", "ge", "eq", "ne")]
    + [
        "__{}__".format(op)
        for op in ("neg", "pos", "abs", "invert", "round", "floor", "ceil")
    ]
    + [
        "__{}{}__".format(prefix, op)
        for prefix in ("", "r", "i")
        for op in (
            "add",
            "sub",
            "mul",
            "matmul",
            "truediv",
            "floordiv",
            "mod",
            "pow",
            "lshift",
            "rshift",
            "and",
            "or",
            "xor",
        )
    ]
    + [
        "T",
        "astype",
        "reshape",
        "_broadcast",
        "transpose",
        "flatten",
        "sum",
        "prod",
        "min",
        "max",
        "mean",
    ]
)



def active_module_tracer(): def active_module_tracer():
return _active_module_tracer return _active_module_tracer
@@ -108,9 +174,8 @@ class Patcher:
self.wrap_fn = wrap_fn self.wrap_fn = wrap_fn
for module in self._builtin_modules: for module in self._builtin_modules:
self.patch_module(module) self.patch_module(module)

for cls in self._builtin_methods:
self.patch_cls(cls)
for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)


for i, j in self._builtin_functions: for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids: if id(i) not in self.visited_frames_ids:


+ 2
- 0
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -13,6 +13,7 @@ import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .pytree import TreeDef




class Node: class Node:
@@ -58,6 +59,7 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
graph = None graph = None
attr_type_map = None # type: Dict[str, Type[Any]] attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef


def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:


+ 80
- 0
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -0,0 +1,80 @@
from typing import Callable, NamedTuple

# Registry of container types that can be flattened into leaves and rebuilt.
# Keys are the exact container types; values are NodeType entries.
SUPPORTED_TYPE = {}

# Per-type pair of callables:
#   flatten(container)              -> (children, aux_data)
#   unflatten(children, aux_data)   -> container
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten, unflatten):
    """
    Register a container type in ``SUPPORTED_TYPE``.

    :param type: the container type used as the registry key.
    :param flatten: callable mapping a container to ``(children, aux_data)``.
    :param unflatten: callable mapping ``(children, aux_data)`` back to a
        container of ``type``.
    """
    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
# A tuple must be rebuilt as a tuple so that flatten/unflatten round-trips the
# original container type instead of silently turning it into a list.
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
# dict: the values are the children, the keys travel as aux_data.
register_supported_type(
    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
# slice: start/stop/step are the children.
register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)


def tree_flatten(
    values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
):
    """
    Flatten a nested container into a flat list of leaves plus a tree
    definition describing its structure.

    A value whose exact type is registered in ``SUPPORTED_TYPE`` is taken apart
    with the registered ``flatten`` callable and its children are flattened
    recursively; any other value is treated as a leaf.

    :param values: the (possibly nested) value to flatten.
    :param leaf_type: callable mapping a leaf to the type recorded in its
        ``LeafDef``; defaults to the leaf's own type.
    :param is_leaf: predicate asserted on every leaf, at every depth.
    :return: ``(leaves, treedef)`` where ``leaves`` is a flat list and
        ``treedef`` is a ``LeafDef`` or ``TreeDef``.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values)
        return [values,], LeafDef(leaf_type(values))
    leaves = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
    for child in children_values:
        # Forward is_leaf as well as leaf_type so that the caller's predicate
        # is applied to nested leaves, not only to a top-level leaf.
        child_leaves, child_def = tree_flatten(child, leaf_type, is_leaf)
        leaves.extend(child_leaves)
        children_defs.append(child_def)

    return leaves, TreeDef(type(values), aux_data, children_defs)


class TreeDef:
    """
    Structure description of one container node of a flattened tree.

    :param type: the container type of this node.
    :param aux_data: extra data produced by the type's ``flatten`` callable
        and handed back to its ``unflatten`` callable.
    :param children_defs: definitions of the node's children, in order.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        # Total number of leaves below this node.
        self.num_leaves = sum(ch.num_leaves for ch in children_defs)

    def unflatten(self, leaves):
        """
        Rebuild the container described by this definition.

        :param leaves: sliceable sequence holding exactly ``num_leaves`` items.
        :return: the value produced by the registered ``unflatten`` callable
            of ``self.type``.
        """
        assert len(leaves) == self.num_leaves
        start = 0
        children = []
        for ch in self.children_defs:
            # Each child consumes its own contiguous run of leaves.
            children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
            start += ch.num_leaves
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign objects
        # instead of failing with AttributeError on a missing attribute.
        if not isinstance(other, TreeDef):
            return NotImplemented
        return (
            self.type == other.type
            and self.aux_data == other.aux_data
            and self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def __repr__(self):
        return "{}[{}]".format(self.type.__name__, self.children_defs)


class LeafDef(TreeDef):
    """
    Definition of a single leaf of a flattened tree.

    :param type: the type, or tuple of types, that the leaf value must be an
        instance of when the tree is rebuilt.
    """

    def __init__(self, type):
        super().__init__(type, None, [])
        # A leaf has no children but always stands for exactly one value.
        self.num_leaves = 1

    def unflatten(self, leaves):
        """
        Return the single leaf value after checking its type.

        :param leaves: sequence holding exactly one value.
        """
        assert len(leaves) == 1
        assert isinstance(leaves[0], self.type), self.type
        return leaves[0]

    def __repr__(self):
        # ``type`` may be a tuple of types (a custom leaf_type callable can
        # return one); a tuple has no ``__name__``, so format its members.
        if isinstance(self.type, tuple):
            name = "({})".format(", ".join(t.__name__ for t in self.type))
        else:
            name = self.type.__name__
        return "Leaf({})".format(name)

+ 65
- 53
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,9 +9,11 @@
import collections import collections
import copy import copy
import functools import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type from typing import List, Type


from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import ( from ...core._imperative_rt.core2 import (
is_tracing_module, is_tracing_module,
set_module_tracing, set_module_tracing,
@@ -28,6 +30,16 @@ from .module_tracer import (
set_active_module_tracer, set_active_module_tracer,
) )
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten


def _leaf_type(node):
    """
    Map a leaf value to the type(s) recorded for it in the tree definition.

    RawTensor instances map to ``(Tensor, TensorNode)``; NodeMixin or Module
    instances map to ``(Module, ModuleNode, NodeMixin)``; any other value maps
    to its own type.
    """
    if isinstance(node, RawTensor):
        return (Tensor, TensorNode)
    if isinstance(node, (NodeMixin, Module)):
        return (Module, ModuleNode, NodeMixin)
    return type(node)




class InternalGraph: class InternalGraph:
@@ -65,9 +77,7 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs): for n, v in zip(self._inputs, inputs):
node2value[n] = v node2value[n] = v
for expr in self._exprs: for expr in self._exprs:
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
for n, v in zip(expr.outputs, values): for n, v in zip(expr.outputs, values):
node2value[n] = v node2value[n] = v
return list(node2value[i] for i in self._outputs) return list(node2value[i] for i in self._outputs)
@@ -80,37 +90,39 @@ class InternalGraph:
) )




def _get_meth_name(obj, func):
for cls in type(obj).mro():
for k, v in cls.__dict__.items():
if v == func:
return k
return None


def _wrapped_function(orig_func): def _wrapped_function(orig_func):
@functools.wraps(orig_func) @functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs):
def wrapped_fn(*args, **kwargs):
if is_tracing_module(): if is_tracing_module():
unset_module_tracing() unset_module_tracing()
const_kwargs = {}
arg_names = orig_func.__code__.co_varnames
if orig_func.__qualname__.split(".").__len__() > 1:
# FIXME: a robust way to distinguish method and function. <XP>
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0] self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
else: else:
call_node = CallFunction.make(orig_func) call_node = CallFunction.make(orig_func)


def add_input(inp, varname=None):
node = Expr.get_args_node(inp)
if node is not None:
call_node.add_input(node, varname)
else:
const_kwargs[varname] = inp

for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_inputs(inputs)

call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
call_node.add_outputs(outputs) call_node.add_outputs(outputs)
set_module_tracing() set_module_tracing()
return outputs return outputs
return orig_func(*inputs, **kwargs)
return orig_func(*args, **kwargs)


return wrapped_fn return wrapped_fn


@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
_body = None # type: InternalGraph _body = None # type: InternalGraph
_is_builtin = None # type: bool _is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [ __builder_attributes__ = [
"_mod", "_mod",
"_body", "_body",
"_NodeMixin__node", "_NodeMixin__node",
"_is_builtin", "_is_builtin",
"_is_traced", "_is_traced",
"build",
"_arg_def" "build",
] ]


def __init__(self, mod): def __init__(self, mod):
@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin):
node = NodeMixin.get(self) node = NodeMixin.get(self)
node.graph = self._body node.graph = self._body
node.attr_type_map = {} node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node) traced_module = TracedModule(node)
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin):
traced_module.m_node.attr_type_map[k] = type(v) traced_module.m_node.attr_type_map[k] = type(v)
return traced_module return traced_module


def __call__(self, *inputs, **kwargs):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module) assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)


for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def mark_constant(x): def mark_constant(x):
node = NodeMixin.get(x, None) node = NodeMixin.get(x, None)
if node is None: # capture as constant if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x)) NodeMixin.wrap(x, lambda: Constant.make(x))


inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
for i in inputs: for i in inputs:
mark_constant(i) mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = CallMethod.make(NodeMixin.get(self)) callnode = CallMethod.make(NodeMixin.get(self))


def add_input(x):
callnode.add_input(NodeMixin.get(x))
callnode.add_inputs(inputs)


for i in inputs:
add_input(i)
for k, v in kwargs.items():
add_input(v)
callnode.arg_def = tree_def


if self._is_builtin or self._is_traced: if self._is_builtin or self._is_traced:
unset_module_tracing() unset_module_tracing()
outputs = self._mod(*inputs, **kwargs)
outputs = self._mod(*args, **kwargs)
set_module_tracing() set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin):
) )
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): def wrap(x):
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
wrapped = copy.copy(x) # FIXME
NodeMixin.wrap( NodeMixin.wrap(
wrapped, wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
) )
return wrapped return wrapped


args = []
for i in inputs:
args = [self]
for i in inputs[1:]:
args.append(wrap(i)) args.append(wrap(i))
for k, v in kwargs.items():
kwargs[k] = wrap(v)
args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch( active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
) )
outputs = type(self._mod).forward(self, *args, **kwargs)
outputs = type(self._mod).forward(*args, **kwargs)


for i in ( for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
@@ -269,8 +282,10 @@ class TracedModule(Module):
super(TracedModule, self).__init__() super(TracedModule, self).__init__()
self.m_node = node self.m_node = node


def forward(self, *inputs):
rst = self.m_node.graph.interpret(self, *inputs)
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1: if len(rst) == 1:
rst = rst[0] rst = rst[0]
return rst return rst
@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:




def _register_all_builtin_module(): def _register_all_builtin_module():
from inspect import getmembers, isclass


for sub_mod in [M, M.qat, M.quantized]: for sub_mod in [M, M.qat, M.quantized]:
for m in getmembers(sub_mod): for m in getmembers(sub_mod):
@@ -357,7 +371,7 @@ def _register_all_builtin_module():
module_tracer.register_as_builtin(m[1]) module_tracer.register_as_builtin(m[1])




def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
""" """
Traces module ``mod`` and returns corresponding TracedModule. Traces module ``mod`` and returns corresponding TracedModule.


@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule


builder = TracedModuleBuilder(mod) builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))

builder(*inputs, **kwargs)
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
builder(*args, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()

return builder.build() return builder.build()
finally: finally:
set_active_module_tracer(None) set_active_module_tracer(None)


Loading…
Cancel
Save