
feat(traced_module): add functional trace and CallMethod/Function expr

GitOrigin-RevId: ad2cdc1b61
release-1.6
Megvii Engine Team 4 years ago
parent commit bee305beb2
4 changed files with 338 additions and 53 deletions
  1. imperative/python/megengine/experimental/traced_module/expr.py (+105, -27)
  2. imperative/python/megengine/experimental/traced_module/module_tracer.py (+109, -1)
  3. imperative/python/megengine/experimental/traced_module/node.py (+4, -0)
  4. imperative/python/megengine/experimental/traced_module/traced_module.py (+120, -25)
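For orientation, a minimal usage sketch of what this commit records during tracing. The example Net module, the input values, and the exact public import path are illustrative assumptions, not part of this diff; the attribute access at the end follows the internals shown below.

# Minimal sketch (assumptions noted above), not part of the commit.
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module.traced_module import trace_module

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.linear = M.Linear(4, 4)

    def forward(self, x):
        # self.linear(x) is recorded as a CallMethod expr (module __call__);
        # F.relu resolves through the patched functional namespace and is
        # recorded as a CallFunction expr.
        return F.relu(self.linear(x))

traced = trace_module(Net(), Tensor([[1.0, 2.0, 3.0, 4.0]]))
for expr in traced.m_node.graph._exprs:  # inspect the recorded exprs
    print(expr)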

imperative/python/megengine/experimental/traced_module/expr.py (+105, -27)

@@ -9,12 +9,13 @@


import collections
from typing import List
from typing import Callable, List

from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode

class Expr:
"""
``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``.
``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
"""

inputs = None # type: List[Node]
outputs = None # type: List[Node]

def add_input(self, node):
self.inputs.append(node)

def add_outputs(self, outputs):
self.outputs = []
if not isinstance(outputs, collections.abc.Sequence):
outputs = (outputs,)

for i in outputs:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

for i, node in zip(outputs, self.outputs):
NodeMixin.wrap_safe(i, node)

@classmethod
def get_args_node(cls, arg):
"""
Create nodes from ``arg``, which may be a container.
Return the same structure as ``arg``.

If ``arg`` is not a Tensor or Module, it will be stored as a const.

:param arg: tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
else:
# TODO: assert arg type
return arg # as const

@classmethod
def get_arg_value(cls, inp_node, node2value):
"""
Get values from ``node2value`` by ``inp_node``, which may be a container.
Return the same structure as ``inp_node``.

If ``inp_node`` is not in ``node2value``, it is a const.

:param inp_node: nodes.
:param node2value: dict mapping nodes to tensors and modules.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
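
The two helpers above are mirror images over nested containers. A standalone sketch (plain Python, not MegEngine code) of the structure-preserving walk they both rely on:

import collections.abc

def map_structure(arg, leaf_fn, is_leaf):
    # Map "interesting" leaves, recurse into sequences, and pass everything
    # else through unchanged (the "stored as const" case above).
    if is_leaf(arg):
        return leaf_fn(arg)
    if isinstance(arg, collections.abc.Sequence) and not isinstance(arg, str):
        return type(arg)(map_structure(a, leaf_fn, is_leaf) for a in arg)
    return arg

# prints (10, [20, 'k']) -- same nesting as the input, only int leaves mapped
print(map_structure((1, [2, "k"]), lambda x: x * 10, lambda x: isinstance(x, int)))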


# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
@@ -83,23 +138,22 @@ class GetAttr(Expr):


# expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
def __init__(self, module):
assert isinstance(module, ModuleNode)
class CallMethod(Expr):
def __init__(self, module, method="__call__"):
assert isinstance(module, (TensorNode, ModuleNode))
self.inputs = [
module,
]
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node):
def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)

def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)

for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
if arg_name is not None:
self.arg_names.append(arg_name)

@classmethod
def make(cls, *args, **kwargs):
@@ -110,15 +164,16 @@ class Call(Expr):
def interpret(self, *inputs):
mod = inputs[0]
args = inputs[1:]
outputs = mod(*args)
outputs = getattr(mod, self.method)(*args, **self.kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
return outputs

def __repr__(self):
return "{} = Call({})({})".format(
return "{} = CallMethod({}, {})({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]),
)

@@ -132,17 +187,6 @@ class Apply(Expr):
self.opdef = opdef
self.inputs = []

def add_input(self, node):
self.inputs.append(node)

def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)

for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
@@ -179,6 +223,40 @@ class Apply(Expr):
return list(outputs)


class CallFunction(Expr):
def __init__(self, func):
assert isinstance(func, Callable)
self.func = func
self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs

def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)

@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
return expr

def interpret(self, *inputs):
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)])
outputs = self.func(**inp_dict, **self.kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
return outputs

def __repr__(self):
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs),
)
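
To make the name rebinding in interpret concrete, a standalone sketch (plain Python, hypothetical helper names) of replaying a recorded call from positional inputs, the recorded arg_names, and the constant kwargs:

def replay(func, arg_names, const_kwargs, *inputs):
    # Bind traced values back to their recorded parameter names, then merge
    # the constant kwargs that were never turned into nodes.
    inp_dict = dict(zip(arg_names, inputs))
    return func(**inp_dict, **const_kwargs)

def _scaled_add(x, y, scale=1.0):
    return (x + y) * scale

print(replay(_scaled_add, ["x", "y"], {"scale": 0.5}, 2, 4))  # -> 3.0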


# expr: outputs = self.value
class Constant(Expr):
value = None


imperative/python/megengine/experimental/traced_module/module_tracer.py (+109, -1)

@@ -6,7 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections

from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module

_active_module_tracer = None
@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer):

class module_tracer:

# builtin types
_opaque_types = set()

_active_scopes = None

def __init__(self):
def __init__(self, wrap_fn):
self._active_scopes = []
self.patcher = Patcher(wrap_fn)

@classmethod
def register_as_builtin(cls, mod):
@@ -50,3 +56,105 @@ class module_tracer:
if self._active_scopes:
return self._active_scopes[-1]
return None


class PatchedFn:
frame_dict = None
name = None
origin_fn = None

def __init__(self, frame_dict, name):
self.frame_dict = frame_dict
self.name = name
self.origin_fn = (
self.frame_dict[name]
if isinstance(frame_dict, collections.abc.Mapping)
else getattr(frame_dict, name)
)

def set_func(self, func):
if isinstance(self.frame_dict, collections.abc.Mapping):
self.frame_dict[self.name] = func
else:
setattr(self.frame_dict, self.name, func)


class Patcher:

patched_fn_ids = set()
_builtin_functions = []
_builtin_modules = [
F,
F.distributed,
F.elemwise,
F.inplace,
F.loss,
F.math,
F.metric,
F.nn,
F.quantized,
F.tensor,
F.utils,
F.vision,
]
_builtin_methods = [
Tensor,
ArrayMethodMixin,
]

def __init__(self, wrap_fn):
self.patched_fn = []
self.visited_frames_ids = set()
self.wrap_fn = wrap_fn
for module in self._builtin_modules:
self.patch_module(module)

for cls in self._builtin_methods:
self.patch_cls(cls)

for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)

def patch_function(self, frame_dict, fn, wrap_fn):
patched_fn = PatchedFn(frame_dict, fn)
self.patched_fn_ids.add(id(patched_fn.origin_fn))
patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
self.patched_fn.append(patched_fn)

def patch_method(self, cls, name, wrap_fn):
self.patch_function(cls, name, wrap_fn)

def patch_cls(self, cls):
import inspect

if id(cls) not in self.visited_frames_ids:
for k, v in cls.__dict__.items():
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(cls, k, self.wrap_fn)
self.visited_frames_ids.add(id(cls))

def patch_module(self, module):
import inspect

if id(module.__dict__) not in self.visited_frames_ids:
for k, v in module.__dict__.items():
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__))

def auto_patch(self, frame_dict):
if id(frame_dict) not in self.visited_frames_ids:
for k, v in frame_dict.items():
if id(v) in self.patched_fn_ids:
self.patch_function(frame_dict, k, self.wrap_fn)
self.visited_frames_ids.add(id(frame_dict))

def __enter__(self):
return self

def __exit__(self, type, value, trace):
while self.patched_fn:
pf = self.patched_fn.pop()
pf.set_func(pf.origin_fn)
self.visited_frames_ids.clear()
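
The patch/restore pattern Patcher implements, in miniature and independent of MegEngine (math.sqrt is just a stand-in for a functional API entry):

import math

_orig_sqrt = math.sqrt

def _traced_sqrt(x):
    # Stand-in for wrap_fn: record the call, then delegate to the original.
    print("recorded: sqrt(%r)" % (x,))
    return _orig_sqrt(x)

math.sqrt = _traced_sqrt      # what patch_function does via PatchedFn.set_func
try:
    math.sqrt(9.0)            # traced call, result unchanged
finally:
    math.sqrt = _orig_sqrt    # what __exit__ does when unwinding patched_fn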

imperative/python/megengine/experimental/traced_module/node.py (+4, -0)

@@ -34,6 +34,10 @@ class Node:
Node.__total_id += 1
self._name = name

def __setstate__(self, d):
self.__dict__ = d
Node.__total_id = max(Node.__total_id, self._id) + 1

def __repr__(self):
if self._name is None:
return "%{}".format(self._id)
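
The __setstate__ added above keeps ids minted after unpickling from colliding with restored ones. A standalone sketch of that counter-bumping pattern (plain Python, not the MegEngine Node class):

import pickle

class Counted:
    _total_id = 0

    def __init__(self):
        self._id = Counted._total_id
        Counted._total_id += 1

    def __setstate__(self, d):
        # Restore the instance, then push the class counter past its id.
        self.__dict__ = d
        Counted._total_id = max(Counted._total_id, self._id) + 1

a = Counted()                       # _id == 0
b = pickle.loads(pickle.dumps(a))   # restores _id == 0, counter moves past it
c = Counted()                       # fresh id, guaranteed not to collide
print(a._id, b._id, c._id)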


imperative/python/megengine/experimental/traced_module/traced_module.py (+120, -25)

@@ -8,14 +8,25 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
from typing import List, Type

from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .module_tracer import (
Patcher,
active_module_tracer,
module_tracer,
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode


@@ -54,7 +65,9 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)
@@ -67,6 +80,41 @@ class InternalGraph:
)


def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs):
if is_tracing_module():
unset_module_tracing()
const_kwargs = {}
arg_names = orig_func.__code__.co_varnames
if orig_func.__qualname__.split(".").__len__() > 1:
# FIXME: a robust way to distinguish method and function. <XP>
self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
else:
call_node = CallFunction.make(orig_func)

def add_input(inp, varname=None):
node = Expr.get_args_node(inp)
if node is not None:
call_node.add_input(node, varname)
else:
const_kwargs[varname] = inp

for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*inputs, **kwargs)

return wrapped_fn
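
The method-vs-function check above relies on __qualname__ containing a dot for methods; a quick standalone illustration of that heuristic, and of why the diff marks it FIXME (nested functions also carry dotted qualnames):

class A:
    def m(self):
        pass

def f():
    def g():
        pass
    return g

print(A.m.__qualname__)   # 'A.m'          -> dotted, treated as a method
print(f.__qualname__)     # 'f'            -> no dot, treated as a plain function
print(f().__qualname__)   # 'f.<locals>.g' -> dotted, yet not a method (the FIXME case)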


class TracedModuleBuilder(NodeMixin):

_mod = None # type: Module
@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = Call.make(NodeMixin.get(self))
callnode = CallMethod.make(NodeMixin.get(self))

def add_input(x):
callnode.add_input(NodeMixin.get(x))
@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin):
)
# prepare args and kwargs for inner graph
def wrap(x):
wrapped = copy.copy(x) # FIXME
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin):
args.append(wrap(i))
for k, v in kwargs.items():
kwargs[k] = wrap(v)

active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs)

for i in (
@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin):

# rebind output to outer graph
callnode.add_outputs(outputs)
for i, node in zip(
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
callnode.outputs,
):
NodeMixin.wrap_safe(i, node)
return outputs

def __getattr__(self, name):
@@ -229,6 +275,55 @@ class TracedModule(Module):
rst = rst[0]
return rst

@property
def all_exprs(self):
"""
Visit all ``Expr``s in the graph recursively.

:return: List[Expr]
"""

in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]

def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
call.inputs[0] = module
return (call,)

exprs = []

graph = module.m_node.graph
for expr in graph._exprs:

# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
expr.inputs[idx] = call.inputs[idx]
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
expr.outputs[idx] = call.outputs[idx]

if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
exprs.append(const)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
else:
exprs.append(expr)
else:
exprs.append(expr)

return exprs

return in_nodes + _flatten_submodule(self)

def __getstate__(self):
d = self.__dict__
for k in Module.__dict__:
@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
assert active_module_tracer() is None
try:
set_module_tracing()
set_active_module_tracer(module_tracer())
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)
set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)

builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))

for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))

builder(*inputs, **kwargs)
active_module_tracer().pop_scope()
builder(*inputs, **kwargs)
active_module_tracer().pop_scope()

return builder.build()
return builder.build()
finally:
set_active_module_tracer(None)
unset_module_tracing()
