Browse Source

feat(traced_module): add functional trace and CallMethod/Function expr

GitOrigin-RevId: ad2cdc1b61
release-1.6
Megvii Engine Team 4 years ago
parent
commit
bee305beb2
4 changed files with 338 additions and 53 deletions
  1. +105
    -27
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +109
    -1
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +4
    -0
      imperative/python/megengine/experimental/traced_module/node.py
  4. +120
    -25
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 105
- 27
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -9,12 +9,13 @@




import collections import collections
from typing import List
from typing import Callable, List


from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .module_tracer import active_module_tracer from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode


class Expr:
    """
    ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
    """

    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]

    def add_input(self, node):
        """
        Append ``node`` to the inputs of this expr.

        :param node: the input to record.
        """
        self.inputs.append(node)

    def add_outputs(self, outputs):
        """
        Create one output node per value in ``outputs`` and bind each value to
        its node with ``NodeMixin.wrap_safe``.

        :param outputs: a single value or a sequence of values.
        """
        self.outputs = []
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        for i in outputs:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

        for i, node in zip(outputs, self.outputs):
            NodeMixin.wrap_safe(i, node)

    @classmethod
    def get_args_node(cls, arg):
        """
        Create nodes by ``arg``, which may be a container.
        Return the same structure with arg.

        If ``arg`` was not Tensor or Module, it will be stored as const.

        :param arg: tensor, module or const.
        """
        if isinstance(arg, (RawTensor, Module)):
            if not NodeMixin.get(arg, None):
                NodeMixin.wrap_safe(arg, Constant.make(arg))
            return NodeMixin.get(arg)
        # ``str``/``bytes`` are sequences too, but must be kept as plain
        # consts: iterating a one-character string yields the same string
        # again, which would recurse without end.
        if isinstance(arg, collections.abc.Sequence) and not isinstance(
            arg, (str, bytes)
        ):
            seq_cls = type(arg)
            return seq_cls([cls.get_args_node(a) for a in arg])
        # TODO: assert arg type
        return arg  # as const

    @classmethod
    def get_arg_value(cls, inp_node, node2value):
        """
        Get values from node2value by inp_node, which may be a container.
        Return the same structure with inp_node.

        If ``inp_node`` was not in node2value, it is a const.

        :param inp_node: nodes.
        :param node2value: dict from node to tensor and module.
        """
        try:
            if inp_node in node2value:
                return node2value[inp_node]
        except TypeError:
            # Unhashable containers (e.g. a list of nodes) can never be keys
            # of ``node2value``; fall through and resolve them element-wise.
            pass
        # Strings are consts, not containers (see ``get_args_node``).
        if isinstance(inp_node, collections.abc.Sequence) and not isinstance(
            inp_node, (str, bytes)
        ):
            seq_cls = type(inp_node)
            return seq_cls([cls.get_arg_value(i, node2value) for i in inp_node])
        return inp_node



# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
@@ -83,23 +138,22 @@ class GetAttr(Expr):




# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
def __init__(self, module):
assert isinstance(module, ModuleNode)
class CallMethod(Expr):
def __init__(self, module, method="__call__"):
assert isinstance(module, (TensorNode, ModuleNode))
self.inputs = [ self.inputs = [
module, module,
] ]
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs


def add_input(self, node):
def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node) self.inputs.append(node)

def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)

for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
if arg_name is not None:
self.arg_names.append(arg_name)


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
@@ -110,15 +164,16 @@ class Call(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
mod = inputs[0] mod = inputs[0]
args = inputs[1:] args = inputs[1:]
outputs = mod(*args)
outputs = getattr(mod, self.method)(*args, **self.kwargs)
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = (outputs,) outputs = (outputs,)
return outputs return outputs


def __repr__(self): def __repr__(self):
return "{} = Call({})({})".format(
return "{} = CallMethod({}, {})({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.inputs[0], self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]), ", ".join(str(i) for i in self.inputs[1:]),
) )


@@ -132,17 +187,6 @@ class Apply(Expr):
self.opdef = opdef self.opdef = opdef
self.inputs = [] self.inputs = []


def add_input(self, node):
self.inputs.append(node)

def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)

for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
@@ -179,6 +223,40 @@ class Apply(Expr):
return list(outputs) return list(outputs)




class CallFunction(Expr):
    """
    Expr that records a call of a plain function.

    Inputs are stored together with the argument name they were passed as;
    ``kwargs`` holds the constant keyword arguments of the call.
    """

    def __init__(self, func):
        assert isinstance(func, Callable)
        self.inputs = []
        self.arg_names = []
        self.kwargs = {}  # const kwargs
        self.func = func

    def add_input(self, node, arg_name):
        """
        Record ``node`` as an input passed under the name ``arg_name``.

        :param node: the input to record.
        :param arg_name: name of the function argument ``node`` is bound to.
        """
        self.arg_names.append(arg_name)
        self.inputs.append(node)

    @classmethod
    def make(cls, *args, **kwargs):
        """Build the expr and insert it into the current scope of the active tracer."""
        created = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.insert(created)
        return created

    def interpret(self, *inputs):
        """
        Call the recorded function with ``inputs`` bound to the recorded
        argument names, plus the constant kwargs.

        :param inputs: values for ``self.inputs``, in the same order.
        :return: the result itself if it is a sequence, else a 1-tuple of it.
        """
        named_args = dict(zip(self.arg_names, inputs))
        result = self.func(**named_args, **self.kwargs)
        if isinstance(result, collections.abc.Sequence):
            return result
        return (result,)

    def __repr__(self):
        outs = ", ".join(str(i) for i in self.outputs)
        qualified_name = self.func.__module__ + "." + self.func.__name__
        ins = ", ".join(str(i) for i in self.inputs)
        return "{} = {}({})".format(outs, qualified_name, ins)


# expr outputs = self.value # expr outputs = self.value
class Constant(Expr): class Constant(Expr):
value = None value = None


+ 109
- 1
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -6,7 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections


from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ...module import Module


_active_module_tracer = None _active_module_tracer = None
@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer):


class module_tracer: class module_tracer:


# builtin types
_opaque_types = set() _opaque_types = set()


_active_scopes = None _active_scopes = None


def __init__(self):
def __init__(self, wrap_fn):
self._active_scopes = [] self._active_scopes = []
self.patcher = Patcher(wrap_fn)


@classmethod @classmethod
def register_as_builtin(cls, mod): def register_as_builtin(cls, mod):
@@ -50,3 +56,105 @@ class module_tracer:
if self._active_scopes: if self._active_scopes:
return self._active_scopes[-1] return self._active_scopes[-1]
return None return None


class PatchedFn:
    """
    Remember where a callable lives so that it can be replaced and restored.

    ``frame_dict`` is either a mapping (looked up by key) or any other object
    (looked up by attribute); ``name`` is the key / attribute name and
    ``origin_fn`` is the value found there at construction time.
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = self.frame_dict[name]
        else:
            self.origin_fn = getattr(frame_dict, name)

    def set_func(self, func):
        """
        Store ``func`` under ``self.name`` in the remembered container.

        :param func: the object to install.
        """
        holder = self.frame_dict
        if not isinstance(holder, collections.abc.Mapping):
            setattr(holder, self.name, func)
            return
        holder[self.name] = func


class Patcher:
    """
    Replace functions with ``wrap_fn(function)`` and restore them on exit.

    All entries of ``_builtin_modules``, ``_builtin_methods`` and
    ``_builtin_functions`` are patched as soon as the instance is created;
    leaving the ``with`` block puts every original back.

    :param wrap_fn: callable that takes the original function and returns its
        replacement.
    """

    # ids of every original function patched so far; this set lives on the
    # class, so it is shared by all instances and is not cleared on exit.
    patched_fn_ids = set()
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for module in self._builtin_modules:
            self.patch_module(module)

        for cls in self._builtin_methods:
            self.patch_cls(cls)

        for i, j in self._builtin_functions:
            if id(i) not in self.visited_frames_ids:
                self.patch_function(i, j, self.wrap_fn)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """
        Replace the entry named ``fn`` of ``frame_dict`` by ``wrap_fn(original)``.

        :param frame_dict: mapping or object that holds the function.
        :param fn: key / attribute name of the function inside ``frame_dict``.
        :param wrap_fn: callable producing the replacement from the original.
        """
        patched_fn = PatchedFn(frame_dict, fn)
        self.patched_fn_ids.add(id(patched_fn.origin_fn))
        patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
        # remembered so that ``__exit__`` can restore the original
        self.patched_fn.append(patched_fn)

    def patch_method(self, cls, name, wrap_fn):
        """Patch attribute ``name`` of ``cls``; same as ``patch_function``."""
        self.patch_function(cls, name, wrap_fn)

    def _patch_public_functions(self, frame_dict, members):
        """Patch every plain function in ``members`` whose name has no leading underscore."""
        import inspect

        # Iterate a snapshot: patching rebinds entries of the namespace that
        # ``members`` is a live view of.
        for k, v in list(members.items()):
            if inspect.isfunction(v) and not k.startswith("_"):
                self.patch_function(frame_dict, k, self.wrap_fn)

    def patch_cls(self, cls):
        """
        Patch the public functions defined directly in ``cls.__dict__``.
        A class that was already visited is skipped.
        """
        if id(cls) not in self.visited_frames_ids:
            self._patch_public_functions(cls, cls.__dict__)
            self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        """
        Patch the public functions found in ``module.__dict__``.
        A module namespace that was already visited is skipped.
        """
        namespace = module.__dict__
        if id(namespace) not in self.visited_frames_ids:
            self._patch_public_functions(namespace, namespace)
            self.visited_frames_ids.add(id(namespace))

    def auto_patch(self, frame_dict):
        """
        Patch every entry of ``frame_dict`` whose value is one of the original
        functions recorded in ``patched_fn_ids``.
        A mapping that was already visited is skipped.

        :param frame_dict: mapping from name to object.
        """
        if id(frame_dict) not in self.visited_frames_ids:
            # snapshot, because patching rebinds entries of ``frame_dict``
            for k, v in list(frame_dict.items()):
                if id(v) in self.patched_fn_ids:
                    self.patch_function(frame_dict, k, self.wrap_fn)
            self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # undo the patches in reverse order of application
        while self.patched_fn:
            pf = self.patched_fn.pop()
            pf.set_func(pf.origin_fn)
        self.visited_frames_ids.clear()

+ 4
- 0
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -34,6 +34,10 @@ class Node:
Node.__total_id += 1 Node.__total_id += 1
self._name = name self._name = name


def __setstate__(self, d):
    """
    Restore the instance attributes from ``d`` and move the class-wide id
    counter past both its current value and this node's ``_id``.
    """
    self.__dict__ = d
    highest_seen = max(Node.__total_id, self._id)
    Node.__total_id = highest_seen + 1

def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}".format(self._id) return "%{}".format(self._id)


+ 120
- 25
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -8,14 +8,25 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import copy import copy
import functools
from typing import List, Type from typing import List, Type


from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .module_tracer import (
Patcher,
active_module_tracer,
module_tracer,
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode




@@ -54,7 +65,9 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs): for n, v in zip(self._inputs, inputs):
node2value[n] = v node2value[n] = v
for expr in self._exprs: for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
for n, v in zip(expr.outputs, values): for n, v in zip(expr.outputs, values):
node2value[n] = v node2value[n] = v
return list(node2value[i] for i in self._outputs) return list(node2value[i] for i in self._outputs)
@@ -67,6 +80,41 @@ class InternalGraph:
) )




def _wrapped_function(orig_func):
    """
    Wrap ``orig_func`` so that a call made while module tracing is active is
    recorded as a ``CallMethod`` or ``CallFunction`` expr.

    When tracing is not active the wrapper simply forwards to ``orig_func``.

    :param orig_func: the function or method to wrap.
    :return: the wrapping function.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*inputs, **kwargs):
        if not is_tracing_module():
            return orig_func(*inputs, **kwargs)

        # Tracing is switched off while the call is recorded and executed
        # (presumably so that the work done inside ``orig_func`` is not
        # recorded a second time; confirm against the tracer).
        unset_module_tracing()
        try:
            const_kwargs = {}
            arg_names = orig_func.__code__.co_varnames
            if "." in orig_func.__qualname__:
                # FIXME: a robust way to distinguish method and function. <XP>
                self = inputs[0]
                call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
            else:
                call_node = CallFunction.make(orig_func)

            def add_input(inp, varname=None):
                node = Expr.get_args_node(inp)
                if node is not None:
                    call_node.add_input(node, varname)
                else:
                    # ``None`` arguments are kept as constant kwargs
                    const_kwargs[varname] = inp

            for ind, inp in enumerate(inputs):
                add_input(inp, arg_names[ind])
            for k, v in kwargs.items():
                add_input(v, k)
            call_node.kwargs = const_kwargs
            outputs = orig_func(*inputs, **kwargs)
            call_node.add_outputs(outputs)
            return outputs
        finally:
            # Re-enable tracing even if recording or ``orig_func`` raised;
            # otherwise the rest of the trace would silently go unrecorded.
            set_module_tracing()

    return wrapped_fn


class TracedModuleBuilder(NodeMixin): class TracedModuleBuilder(NodeMixin):


_mod = None # type: Module _mod = None # type: Module
@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i) mark_constant(i)
for k, v in kwargs.items(): for k, v in kwargs.items():
mark_constant(v) mark_constant(v)
callnode = Call.make(NodeMixin.get(self))
callnode = CallMethod.make(NodeMixin.get(self))


def add_input(x): def add_input(x):
callnode.add_input(NodeMixin.get(x)) callnode.add_input(NodeMixin.get(x))
@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin):
) )
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): def wrap(x):
wrapped = copy.copy(x) # FIXME
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
NodeMixin.wrap( NodeMixin.wrap(
wrapped, wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin):
args.append(wrap(i)) args.append(wrap(i))
for k, v in kwargs.items(): for k, v in kwargs.items():
kwargs[k] = wrap(v) kwargs[k] = wrap(v)

active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs) outputs = type(self._mod).forward(self, *args, **kwargs)


for i in ( for i in (
@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin):


# rebind output to outer graph # rebind output to outer graph
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
for i, node in zip(
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
callnode.outputs,
):
NodeMixin.wrap_safe(i, node)
return outputs return outputs


def __getattr__(self, name): def __getattr__(self, name):
@@ -229,6 +275,55 @@ class TracedModule(Module):
rst = rst[0] rst = rst[0]
return rst return rst


@property
def all_exprs(self):
    """
    Visit all ``Expr``s in the graph recursively.

    Calls of sub-modules that are ``TracedModule`` themselves are replaced by
    the exprs of the sub-module's graph, rewired to the nodes of the calling
    expr. Note that the visited exprs are modified in place.

    :return: List[Expr]
    """

    # NOTE(review): this compares graph input nodes with the module itself;
    # confirm that filtering against ``self`` is what is intended.
    in_nodes = [i.expr for i in self.m_node.graph._inputs if i is not self]

    def _flatten_submodule(module, call=None):
        if not isinstance(module, TracedModule):
            # leaf module: keep the call and bind it to the module object
            call.inputs[0] = module
            return (call,)

        exprs = []

        graph = module.m_node.graph
        for expr in graph._exprs:

            if call is not None:
                # replace inputs for submodule's expr: the k-th input of the
                # sub-graph corresponds to the k-th input of the call, so
                # the position must be looked up in ``graph._inputs`` and
                # not taken from the position inside ``expr.inputs``.
                for idx, inp in enumerate(expr.inputs):
                    if inp in graph._inputs:
                        expr.inputs[idx] = call.inputs[graph._inputs.index(inp)]
                # replace outputs for submodule's expr, same mapping rule
                for idx, outp in enumerate(expr.outputs):
                    if outp in graph._outputs:
                        expr.outputs[idx] = call.outputs[graph._outputs.index(outp)]

            if isinstance(expr, GetAttr):
                # replace GetAttr with Constant
                if isinstance(expr.outputs[0], TensorNode):
                    const = Constant(getattr(module, expr.name))
                    const.outputs = expr.outputs
                    exprs.append(const)
            elif isinstance(expr, CallMethod):
                obj_node = expr.inputs[0]
                if isinstance(obj_node, ModuleNode):
                    # resolve the called sub-module and expand it
                    (obj,) = obj_node.expr.interpret(module)
                    exprs.extend(_flatten_submodule(obj, expr))
                else:
                    exprs.append(expr)
            else:
                exprs.append(expr)

        return exprs

    return in_nodes + _flatten_submodule(self)

def __getstate__(self): def __getstate__(self):
d = self.__dict__ d = self.__dict__
for k in Module.__dict__: for k in Module.__dict__:
@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
assert active_module_tracer() is None assert active_module_tracer() is None
try: try:
set_module_tracing() set_module_tracing()
set_active_module_tracer(module_tracer())
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)
set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)


builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))


for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))


builder(*inputs, **kwargs)
active_module_tracer().pop_scope()
builder(*inputs, **kwargs)
active_module_tracer().pop_scope()


return builder.build()
return builder.build()
finally: finally:
set_active_module_tracer(None) set_active_module_tracer(None)
unset_module_tracing() unset_module_tracing()

Loading…
Cancel
Save