
feat(traced_module): let CallFunction own graph

GitOrigin-RevId: 66cdbca7e5
release-1.6
Megvii Engine Team, 3 years ago
parent commit 4bb253695b
4 changed files with 115 additions and 65 deletions:
  1. imperative/python/megengine/experimental/traced_module/expr.py (+13, -2)
  2. imperative/python/megengine/experimental/traced_module/node.py (+6, -2)
  3. imperative/python/megengine/experimental/traced_module/pytree.py (+33, -5)
  4. imperative/python/megengine/experimental/traced_module/traced_module.py (+63, -56)

imperative/python/megengine/experimental/traced_module/expr.py (+13, -2)

@@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef

@@ -148,6 +148,15 @@ class CallMethod(Expr):
active_module_tracer().current_scope().insert(expr)
return expr

@property
def graph(self):
if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
return None

def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
@@ -252,7 +261,9 @@ class Constant(Expr):
_constant_cache = {}

def __init__(self, c):
# TODO: type check, since not all types should be captured as constant
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
self.value = c
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
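
For reference, the effect of the new graph property is that a CallMethod no longer reads a single ModuleNode.graph; it resolves its sub-graph through the argdef_graph_map that the ModuleNode keeps per argument TreeDef. Below is a minimal standalone sketch of that lookup; the Stub* classes are hypothetical stand-ins, not MegEngine types.

    # Hypothetical stand-ins that mirror the lookup in CallMethod.graph above.
    class StubModuleNode:
        def __init__(self):
            self.argdef_graph_map = {}   # TreeDef -> InternalGraph

    class StubCallMethod:
        def __init__(self, module_node, arg_def):
            self.inputs = [module_node]  # inputs[0] is the called ModuleNode
            self.arg_def = arg_def       # TreeDef recorded when tracing the call

        @property
        def graph(self):
            m_node = self.inputs[0]
            if m_node.argdef_graph_map:
                assert self.arg_def in m_node.argdef_graph_map
                return m_node.argdef_graph_map[self.arg_def]
            return None                  # builtin modules own no InternalGraph

    m_node = StubModuleNode()
    m_node.argdef_graph_map["arg_def_A"] = "graph traced for signature A"
    call = StubCallMethod(m_node, "arg_def_A")
    print(call.graph)                    # -> graph traced for signature A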


imperative/python/megengine/experimental/traced_module/node.py (+6, -2)

@@ -57,9 +57,13 @@ class ModuleNode(Node):
"""

module_type = Module # type: Type[Module]
graph = None
attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef
argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"]

def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}

def __repr__(self):
if self._name is None:
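
One detail of the ModuleNode change worth illustrating: the class-level attributes stay None and the real dicts are created in __init__, because a dict assigned at class level would be shared by every ModuleNode instance. A small sketch of the pitfall this pattern avoids (the class names here are made up for illustration):

    class SharedByAccident:
        argdef_graph_map = {}            # one dict shared by *all* instances

    class PerInstance:
        argdef_graph_map = None          # declaration only, as in ModuleNode

        def __init__(self):
            self.argdef_graph_map = {}   # fresh dict per node

    a, b = SharedByAccident(), SharedByAccident()
    a.argdef_graph_map["k"] = 1
    print(b.argdef_graph_map)            # {'k': 1} -- leaked across instances

    c, d = PerInstance(), PerInstance()
    c.argdef_graph_map["k"] = 1
    print(d.argdef_graph_map)            # {} -- isolated, as intended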


imperative/python/megengine/experimental/traced_module/pytree.py (+33, -5)

@@ -25,7 +25,7 @@ def _dict_flatten(inp):
for key, value in sorted(inp.items()):
results.append(value)
aux_data.append(key)
return results, aux_data
return results, tuple(aux_data)


def _dict_unflatten(inps, aux_data):
@@ -43,16 +43,23 @@ register_supported_type(


def tree_flatten(
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
values,
leaf_type: Callable = lambda x: type(x),
is_leaf: Callable = lambda _: True,
is_const_leaf: Callable = lambda _: False,
):
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values)
return [values,], LeafDef(leaf_type(values))
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
node.const_val = values
return [values,], node

rst = []
children_defs = []
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
for v in children_values:
v_list, treedef = tree_flatten(v, leaf_type)
v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
rst.extend(v_list)
children_defs.append(treedef)

@@ -75,6 +82,18 @@ class TreeDef:
start += ch.num_leaves
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

def __hash__(self):
return hash(
tuple(
[
self.type,
self.aux_data,
self.num_leaves,
tuple([hash(x) for x in self.children_defs]),
]
)
)

def __eq__(self, other):
return (
self.type == other.type
@@ -93,11 +112,20 @@ class LeafDef(TreeDef):
type = (type,)
super().__init__(type, None, [])
self.num_leaves = 1
self.const_val = None

def unflatten(self, leaves):
assert len(leaves) == 1
assert isinstance(leaves[0], self.type), self.type
return leaves[0]

def __eq__(self, other):
return self.type == other.type and self.const_val == other.const_val

def __hash__(self):
return hash(tuple([self.type, self.const_val]))

def __repr__(self):
return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val
)
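
To make the new parameters concrete, here is a hedged usage sketch of the updated tree_flatten. It assumes MegEngine with this experimental package is installed and importable at the path shown in this diff; exact reprs may differ.

    from megengine.experimental.traced_module.pytree import tree_flatten

    # Treat plain Python values as constant leaves, in the spirit of
    # _is_const_leaf from traced_module.py (this predicate is an assumption
    # for the example, not the library's own helper).
    def is_const(x):
        return isinstance(x, (bool, int, float, str, type(None)))

    args = ({"bias": 1.0, "scale": 2.0}, {})
    leaves, treedef = tree_flatten(args, is_const_leaf=is_const)
    print(leaves)    # [1.0, 2.0] -- dict values, sorted by key
    print(treedef)   # structure plus per-leaf const_val, e.g. Leaf(float[1.0])

    # TreeDef/LeafDef now implement __hash__ and __eq__, so a TreeDef can be
    # used as a dict key -- exactly what ModuleNode.argdef_graph_map relies on.
    graph_map = {treedef: "InternalGraph traced for this call signature"}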

imperative/python/megengine/experimental/traced_module/traced_module.py (+63, -56)

@@ -42,6 +42,12 @@ def _leaf_type(node):
return type(node)


def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
return True


class InternalGraph:
"""
``InternalGraph`` is a graph consisting of ``Node``s and ``Expr``s; it is used to represent the execution procedure of a Module's forward method.
@@ -72,6 +78,10 @@ class InternalGraph:
def outputs(self):
return self._outputs

@property
def exprs(self):
return _expr_list(self)

def add_input(self, i):
self._inputs.append(i)
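
The practical consequence of _is_const_leaf, sketched below: tensors, traced nodes and Modules stay dynamic graph inputs, while every other Python argument is folded into the TreeDef as a const_val. Two calls that differ only in such a value therefore get different arg_defs and separate entries in argdef_graph_map. A hedged illustration using the pytree helpers from this diff (the predicate is an assumption for the example):

    from megengine.experimental.traced_module.pytree import tree_flatten

    def is_const(x):
        return isinstance(x, (bool, int, float, str, type(None)))

    _, def_sum = tree_flatten((("x",), {"mode": "sum"}), is_const_leaf=is_const)
    _, def_mean = tree_flatten((("x",), {"mode": "mean"}), is_const_leaf=is_const)
    print(def_sum == def_mean)   # False: const_val "sum" vs "mean" differ, so
                                 # each call signature would own its own graph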

@@ -111,7 +121,9 @@ def _wrapped_function(orig_func):
def wrapped_fn(*args, **kwargs):
if is_tracing_module():
unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
inputs, tree_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"_is_traced",
"_arg_def" "build",
"build",
]

def __init__(self, mod):
def __init__(self, mod, is_top_module=False):
super(TracedModuleBuilder, self).__init__()
self._mod = mod
self._body = InternalGraph()
self._is_traced = False
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)

def build(self):
@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin):
return self._mod
else:
node = NodeMixin.get(self)
node.graph = self._body
node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin):

def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)

for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph
def mark_constant(x):
node = NodeMixin.get(x, None)
if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x))

inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
inputs, tree_def = tree_flatten(
((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for i in inputs:
mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self))
@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin):

callnode.arg_def = tree_def

if self._is_builtin or self._is_traced:
if self._is_builtin:
unset_module_tracing()
outputs = self._mod(*args, **kwargs)
set_module_tracing()
if self._is_builtin:
self._body = None
else:
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))

NodeMixin.wrap_safe(self, orig_self)
self._is_traced = True
active_module_tracer().pop_scope()

# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
return outputs

def __getattr__(self, name):
@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin):


class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module
def __init__(self, graph: InternalGraph):
self.graph = graph

def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
yield expr
if expr.graph is not None:
yield from expr.graph.exprs
else:
yield expr


class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
`TracedModule` is the Module created by tracing a normal module. It owns a ModuleNode (m_node). A `TracedModule` cannot be called directly; it is
interpreted by a CallMethod Expr.
"""

m_node = None # type: ModuleNode
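
With _expr_list now driven by expr.graph rather than by re-interpreting GetAttr through the owning module, recursive Expr iteration becomes a property of the graph itself. A hedged usage sketch, given a TracedModule tm such as the one produced in the trace_module example at the end of this diff:

    # Recursively walks this graph and every sub-graph reachable through
    # CallMethod.graph; graph._exprs would list only the top-level Exprs.
    for expr in tm.graph.exprs:
        print(expr)

    exprs = list(tm.exprs)       # same sequence: tm.exprs is tm.graph.exprs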
@@ -307,21 +308,24 @@ class TracedModule(Module):
self.m_node = node

def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1:
rst = rst[0]
return rst
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
if len(outputs) == 1:
return outputs[0]
return outputs

@property
def exprs(self):
"""
Get all ``Expr`` s recursively.
def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]

:return: Iterator[Expr]
"""
return _expr_list(self)
@property
def exprs(self):
return self.graph.exprs

def flatten(self):
"""
@@ -331,24 +335,26 @@ class TracedModule(Module):
"""
new_module = copy.deepcopy(self)

def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
call.inputs[0] = module
return (call,)

def _flatten_subgraph(graph, module, call=None):
if graph is None:
assert not isinstance(module, TracedModule)
const = Constant(module)
modulenode = const.outputs[0]
modulenode.module_type = type(module)
call.inputs[0] = modulenode
return [const, call]
exprs = []

graph = module.m_node.graph
for expr in graph._exprs:

# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
expr.inputs[idx] = call.inputs[idx]
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
expr.outputs[idx] = call.outputs[idx]
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]

if isinstance(expr, GetAttr):
# replace GetAttr with Constant
@@ -356,12 +362,13 @@ class TracedModule(Module):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
exprs.append(const)

elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
else:
exprs.append(expr)
else:
@@ -369,7 +376,7 @@ class TracedModule(Module):

return exprs

new_module.m_node.graph._exprs = _flatten_submodule(new_module)
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)

return new_module

@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)

builder = TracedModuleBuilder(mod)
builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
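
Finally, a hedged end-to-end sketch of how the change surfaces to a user. It assumes MegEngine with this experimental package installed; the Simple module and all variable names are made up for illustration, and the behavior is inferred only from the code shown in this diff.

    import megengine.functional as F
    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module.traced_module import trace_module

    class Simple(M.Module):
        def __init__(self):
            super().__init__()
            self.linear = M.Linear(4, 4)

        def forward(self, x):
            return F.relu(self.linear(x))

    tm = trace_module(Simple(), Tensor([[1.0, 2.0, 3.0, 4.0]]))

    # One InternalGraph per argument TreeDef, keyed on the top ModuleNode.
    print(tm.m_node.argdef_graph_map)

    # tm.graph is well defined here because exactly one signature was traced.
    print(tm.graph)

    # Calling the traced module dispatches through argdef_graph_map in forward.
    out = tm(Tensor([[1.0, 2.0, 3.0, 4.0]]))

    # flatten() inlines sub-graphs via _flatten_subgraph.
    flat = tm.flatten()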

