Browse Source

feat(traced_module): add pytree

GitOrigin-RevId: 6c6e53521c
release-1.6
Megvii Engine Team 4 years ago
parent
commit
f2691566fd
5 changed files with 266 additions and 120 deletions
  1. +51
    -64
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +68
    -3
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +2
    -0
      imperative/python/megengine/experimental/traced_module/node.py
  4. +80
    -0
      imperative/python/megengine/experimental/traced_module/pytree.py
  5. +65
    -53
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 51
- 64
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import builtins
import collections
from typing import Callable, List

@@ -19,6 +19,7 @@ from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef


class Expr:
@@ -28,9 +29,22 @@ class Expr:

inputs = None # type: List[Node]
outputs = None # type: List[Node]

def add_input(self, node):
self.inputs.append(node)
const_val = None # type: List[Any]
arg_def = None # type: TreeDef

def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence):
vals = (vals,)
for val in vals:
node = NodeMixin.get(val, None)
if isinstance(node, (TensorNode, ModuleNode)):
if node not in self.inputs:
self.inputs.append(node)
else:
assert node is None
assert type(val) in builtins.__dict__.values()
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))

def add_outputs(self, outputs):
self.outputs = []
@@ -38,50 +52,31 @@ class Expr:
outputs = (outputs,)

for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)

@classmethod
def get_args_node(cls, arg):
"""
Create nodes by ``arg``, which may be a container.
Return the same structure with arg.

If ``arg`` was not Tensor or Module, it will be stored as const.

:param arg: tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
def unflatten_args(self, inputs):
if self.arg_def is not None:
inputs = list(inputs)
for idx, val in self.const_val:
inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
return args, kwargs
else:
# TODO: assert arg type
return arg # as const
return inputs, {}

@classmethod
def get_arg_value(cls, inp_node, node2value):
"""
Get values from node2value by inp_node, which may be a container.
Return the same structure with inp_node.

If ``inp_node`` was not in node2value, it is a const.

:param inp_node: nodes.
:param node2value: dict from node to tensor and module.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
@property
def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs)
return kwargs

@property
def args(self):
args, _ = self.unflatten_args(self.inputs)
return args


# expr: None (i.e. fake expression which is used to mark input)
@@ -144,16 +139,8 @@ class CallMethod(Expr):
self.inputs = [
module,
]
self.const_val = []
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)
if arg_name is not None:
self.arg_names.append(arg_name)

@classmethod
def make(cls, *args, **kwargs):
@@ -162,19 +149,22 @@ class CallMethod(Expr):
return expr

def interpret(self, *inputs):
mod = inputs[0]
args = inputs[1:]
outputs = getattr(mod, self.method)(*args, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
return outputs

def __repr__(self):
return "{} = CallMethod({}, {})({})".format(
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]),
", ".join([args, kwargs]),
)


@@ -227,13 +217,8 @@ class CallFunction(Expr):
def __init__(self, func):
assert isinstance(func, Callable)
self.func = func
self.const_val = []
self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)

@classmethod
def make(cls, *args, **kwargs):
@@ -242,18 +227,20 @@ class CallFunction(Expr):
return expr

def interpret(self, *inputs):
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)])
outputs = self.func(**inp_dict, **self.kwargs)
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
return outputs

def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs),
", ".join([args, kwargs]),
)




+ 68
- 3
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -15,6 +15,72 @@ from ...module import Module

_active_module_tracer = None

# Method names on ArrayMethodMixin that the Patcher hooks via
# ``patch_method(ArrayMethodMixin, meth, self.wrap_fn)`` so that tensor
# method calls are intercepted while module tracing is active.
# Groups: rich comparisons, unary ops, binary ops, reflected (__r*__) ops,
# in-place (__i*__) ops, and ndarray-style helpers ("T", "reshape", ...).
BUILTIN_ARRAY_METHOD = [
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
    "__eq__",
    "__ne__",
    "__neg__",
    "__pos__",
    "__abs__",
    "__invert__",
    "__round__",
    "__floor__",
    "__ceil__",
    "__add__",
    "__sub__",
    "__mul__",
    "__matmul__",
    "__truediv__",
    "__floordiv__",
    "__mod__",
    "__pow__",
    "__lshift__",
    "__rshift__",
    "__and__",
    "__or__",
    "__xor__",
    "__radd__",
    "__rsub__",
    "__rmul__",
    "__rmatmul__",
    "__rtruediv__",
    "__rfloordiv__",
    "__rmod__",
    "__rpow__",
    "__rlshift__",
    "__rrshift__",
    "__rand__",
    "__ror__",
    "__rxor__",
    "__iadd__",
    "__isub__",
    "__imul__",
    "__imatmul__",
    "__itruediv__",
    "__ifloordiv__",
    "__imod__",
    "__ipow__",
    "__ilshift__",
    "__irshift__",
    "__iand__",
    "__ior__",
    "__ixor__",
    "T",
    "astype",
    "reshape",
    "_broadcast",
    "transpose",
    "flatten",
    "sum",
    "prod",
    "min",
    "max",
    "mean",
]


def active_module_tracer():
return _active_module_tracer
@@ -108,9 +174,8 @@ class Patcher:
self.wrap_fn = wrap_fn
for module in self._builtin_modules:
self.patch_module(module)

for cls in self._builtin_methods:
self.patch_cls(cls)
for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)

for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:


+ 2
- 0
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -13,6 +13,7 @@ import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor
from .pytree import TreeDef


class Node:
@@ -58,6 +59,7 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
graph = None
attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef

def __repr__(self):
if self._name is None:


+ 80
- 0
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -0,0 +1,80 @@
from typing import Callable, NamedTuple

# Registry of container types that tree_flatten knows how to traverse.
SUPPORTED_TYPE = {}

# Each supported container maps to a pair of functions:
#   flatten(x)  -> (list_of_children, aux_data)
#   unflatten(children, aux_data) -> rebuilt container
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten, unflatten):
    """Register *type* as a container that tree_flatten/TreeDef can traverse.

    :param type: the container class used as the registry key.
    :param flatten: callable ``x -> (children, aux_data)``.
    :param unflatten: callable ``(children, aux_data) -> container`` that must
        rebuild an object of the same type so flatten/unflatten round-trips.
    """
    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
# bug fix: unflattening a tuple must rebuild a tuple, not a list — otherwise
# TreeDef.unflatten silently changes the container type on round-trip.
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(
    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)


def tree_flatten(
    values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
):
    """Recursively flatten *values* into a flat leaf list plus a TreeDef.

    :param values: an arbitrarily nested combination of the containers in
        SUPPORTED_TYPE; any other value is treated as a leaf.
    :param leaf_type: maps a leaf value to the type tag stored in its LeafDef.
    :param is_leaf: predicate asserted on every leaf value.
    :return: ``(leaves, treedef)`` such that ``treedef.unflatten(leaves)``
        rebuilds an equivalent structure.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values)
        return [values], LeafDef(leaf_type(values))
    rst = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
    for v in children_values:
        # bug fix: forward is_leaf too — previously a custom predicate was
        # silently dropped below the first nesting level.
        v_list, treedef = tree_flatten(v, leaf_type, is_leaf)
        rst.extend(v_list)
        children_defs.append(treedef)

    return rst, TreeDef(type(values), aux_data, children_defs)


class TreeDef:
    """Structural description of a flattened container.

    Stores the container class, its auxiliary data (e.g. dict keys) and one
    child TreeDef per flattened child; ``num_leaves`` is the total number of
    leaves underneath this node.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        self.num_leaves = sum(child.num_leaves for child in children_defs)

    def unflatten(self, leaves):
        """Rebuild a container of ``self.type`` from a flat leaf sequence."""
        assert len(leaves) == self.num_leaves
        children = []
        offset = 0
        for child in self.children_defs:
            end = offset + child.num_leaves
            children.append(child.unflatten(leaves[offset:end]))
            offset = end
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __eq__(self, other):
        if self.type != other.type:
            return False
        if self.aux_data != other.aux_data:
            return False
        return (
            self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def __repr__(self):
        return "{}[{}]".format(self.type.__name__, self.children_defs)


class LeafDef(TreeDef):
    """TreeDef terminal node standing for exactly one leaf value.

    ``type`` may be a single class or a tuple of acceptable classes
    (``isinstance`` accepts both); NOTE(review): ``__repr__`` assumes a
    single class — it would fail on a tuple tag, confirm callers.
    """

    def __init__(self, type):
        super().__init__(type, None, [])
        # a leaf always accounts for exactly one flattened value
        self.num_leaves = 1

    def unflatten(self, leaves):
        assert len(leaves) == 1
        value = leaves[0]
        assert isinstance(value, self.type), self.type
        return value

    def __repr__(self):
        return "Leaf({})".format(self.type.__name__)

+ 65
- 53
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,9 +9,11 @@
import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type

from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
@@ -28,6 +30,16 @@ from .module_tracer import (
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten


def _leaf_type(node):
    """Return the leaf-type tag recorded when flattening traced call args.

    Tensor-like and module-like values may show up either as real objects or
    as their tracing nodes, so the tag is a tuple of acceptable classes
    (checked with ``isinstance`` in ``LeafDef.unflatten``); anything else is
    tagged with its concrete type.
    """
    if isinstance(node, RawTensor):
        return (Tensor, TensorNode)
    if isinstance(node, (NodeMixin, Module)):
        return (Module, ModuleNode, NodeMixin)
    return type(node)


class InternalGraph:
@@ -65,9 +77,7 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)
@@ -80,37 +90,39 @@ class InternalGraph:
)


def _get_meth_name(obj, func):
for cls in type(obj).mro():
for k, v in cls.__dict__.items():
if v == func:
return k
return None


def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs):
def wrapped_fn(*args, **kwargs):
if is_tracing_module():
unset_module_tracing()
const_kwargs = {}
arg_names = orig_func.__code__.co_varnames
if orig_func.__qualname__.split(".").__len__() > 1:
# FIXME: a robust way to distinguish method and function. <XP>
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
else:
call_node = CallFunction.make(orig_func)

def add_input(inp, varname=None):
node = Expr.get_args_node(inp)
if node is not None:
call_node.add_input(node, varname)
else:
const_kwargs[varname] = inp

for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_inputs(inputs)

call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*inputs, **kwargs)
return orig_func(*args, **kwargs)

return wrapped_fn

@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"_is_traced",
"build",
"_arg_def" "build",
]

def __init__(self, mod):
@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin):
node = NodeMixin.get(self)
node.graph = self._body
node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin):
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module

def __call__(self, *inputs, **kwargs):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)

for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph
def mark_constant(x):
node = NodeMixin.get(x, None)
if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x))

inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
for i in inputs:
mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = CallMethod.make(NodeMixin.get(self))

def add_input(x):
callnode.add_input(NodeMixin.get(x))
callnode.add_inputs(inputs)

for i in inputs:
add_input(i)
for k, v in kwargs.items():
add_input(v)
callnode.arg_def = tree_def

if self._is_builtin or self._is_traced:
unset_module_tracing()
outputs = self._mod(*inputs, **kwargs)
outputs = self._mod(*args, **kwargs)
set_module_tracing()
if self._is_builtin:
self._body = None
@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin):
)
# prepare args and kwargs for inner graph
def wrap(x):
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
wrapped = copy.copy(x) # FIXME
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
)
return wrapped

args = []
for i in inputs:
args = [self]
for i in inputs[1:]:
args.append(wrap(i))
for k, v in kwargs.items():
kwargs[k] = wrap(v)
args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs)
outputs = type(self._mod).forward(*args, **kwargs)

for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
@@ -269,8 +282,10 @@ class TracedModule(Module):
super(TracedModule, self).__init__()
self.m_node = node

def forward(self, *inputs):
rst = self.m_node.graph.interpret(self, *inputs)
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1:
rst = rst[0]
return rst
@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:


def _register_all_builtin_module():
from inspect import getmembers, isclass

for sub_mod in [M, M.qat, M.quantized]:
for m in getmembers(sub_mod):
@@ -357,7 +371,7 @@ def _register_all_builtin_module():
module_tracer.register_as_builtin(m[1])


def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
Traces module ``mod`` and returns corresponding TracedModule.

@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule

builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))

builder(*inputs, **kwargs)
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
builder(*args, **kwargs)
active_module_tracer().pop_scope()

return builder.build()
finally:
set_active_module_tracer(None)


Loading…
Cancel
Save