Browse Source

refactor(traced_module): let TracedModule own argdef_graph_map

GitOrigin-RevId: 80d685b9a3
release-1.6
Megvii Engine Team 3 years ago
parent
commit
f88bd3ae33
6 changed files with 177 additions and 67 deletions
  1. +9
    -8
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +3
    -0
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +25
    -7
      imperative/python/megengine/experimental/traced_module/node.py
  4. +138
    -52
      imperative/python/megengine/experimental/traced_module/traced_module.py
  5. +1
    -0
      imperative/python/test/integration/test_converge.py
  6. +1
    -0
      imperative/python/test/integration/test_converge_with_gradient_clip.py

+ 9
- 8
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -9,6 +9,7 @@

import builtins
import collections
import copy
import inspect
from typing import Callable, List

@@ -46,7 +47,7 @@ class Expr:
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))

def add_outputs(self, outputs, check_inplace=True):
def add_outputs(self, outputs):
self.outputs = []
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
@@ -54,10 +55,7 @@ class Expr:

for i in outputs:
assert isinstance(i, RawTensor)
node = NodeMixin.get(i, None) if check_inplace else None
self.outputs.append(
node if node else NodeMixin.get_wrapped_type(i)(self)
)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
@@ -165,9 +163,12 @@ class CallMethod(Expr):
def graph(self):
if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
if (
hasattr(m_node.owner, "argdef_graph_map")
and m_node.owner.argdef_graph_map
):
assert self.arg_def in m_node.owner.argdef_graph_map
return m_node.owner.argdef_graph_map[self.arg_def]
return None

def interpret(self, *inputs):


+ 3
- 0
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -184,6 +184,9 @@ class Patcher:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)

for m in module_tracer._opaque_types:
self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {}))

def patch_function(self, frame_dict, fn, wrap_fn):
patched_fn = PatchedFn(frame_dict, fn)
self.patched_fn_ids.add(id(patched_fn.origin_fn))


+ 25
- 7
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import weakref
from typing import Any, Dict, List, Tuple, Type

import numpy
@@ -58,15 +60,10 @@ class ModuleNode(Node):
"""

module_type = Module # type: Type[Module]
attr_type_map = None # type: Dict[str, Type[Any]]
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map = None # type: Dict[Treedef, Treedef]
_owner = None # type: weakref.ReferenceType

def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
self.argdef_outdef_map = {}

def __repr__(self):
if self._name is None:
@@ -74,6 +71,15 @@ class ModuleNode(Node):
else:
return "%{}({})".format(self._name, self.module_type.__name__)

def __getstate__(self):
    """Return the state used to pickle / copy this node, without ``_owner``.

    ``_owner`` is a weak reference to the module this node is bound to, and
    weak references cannot be pickled, so it is left out of the state.

    Returns:
        A new dict holding every instance attribute except ``_owner``.
    """
    # Work on a shallow copy: popping from ``self.__dict__`` itself would
    # strip ``_owner`` from the live node every time it is pickled or copied.
    state = self.__dict__.copy()
    state.pop("_owner", None)
    return state

@property
def owner(self):
    """The object this node is bound to, or ``None`` when there is none.

    ``_owner`` is either ``None`` (the class-level default, which is also
    what a node falls back to once ``__getstate__`` has dropped the
    attribute) or a weak reference. Dereferencing a weak reference yields
    ``None`` once its target has been garbage collected.
    """
    ref = self._owner
    # Guard the unbound case: calling ``None`` would raise TypeError.
    if ref is None:
        return None
    return ref()


class TensorNode(Node):
"""
@@ -90,9 +96,14 @@ class TensorNode(Node):
return "%{}(Tensor)".format(self._name)


class NodeMixin:
class NodeMixin(abc.ABC):
__node = None

@abc.abstractmethod
def _record_wrapped_nodes(self, node):
    """Record ``node`` as one of the nodes that have been bound to this NodeMixin.

    Abstract hook: the base implementation does nothing and returns ``None``.
    """
    return None

@classmethod
def wrap(cls, value, node):
if isinstance(value, (NodeMixin, RawTensor)):
@@ -102,15 +113,20 @@ class NodeMixin:
node.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
setattr(value, "_NodeMixin__node", node)
else:
assert callable(node)
n = node()
assert isinstance(n, Node)
if isinstance(value, RawTensor):
n.dtype = value.dtype
n.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(n)
setattr(value, "_NodeMixin__node", n)

@classmethod
@@ -122,6 +138,8 @@ class NodeMixin:
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
setattr(value, "_NodeMixin__node", node)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)

@classmethod
def get(cls, value, *default):


+ 138
- 52
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,6 +9,7 @@
import collections
import copy
import functools
import weakref
from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type

@@ -51,7 +52,9 @@ def _leaf_type(node):


def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
)
return isinstance(node, RawTensor)


@@ -107,6 +110,32 @@ class InternalGraph:
def add_output(self, o):
    """Append ``o`` to the end of this graph's output list (``_outputs``)."""
    outputs = self._outputs
    outputs.append(o)

def _replace_inputs_outputs(self, repl_dict):
    """Swap this graph's input/output nodes for the nodes given in ``repl_dict``.

    Args:
        repl_dict: mapping ``{node_of_this_graph: replacement_node}``. Every
            key must be one of this graph's inputs or outputs.

    Effects, all in place:
        * users of each replaced node are appended to the replacement's
          ``users`` list (without duplicates);
        * ``self._inputs`` / ``self._outputs`` entries are replaced;
        * each replacement output takes over the ``expr`` of the output it
          replaces;
        * ``inputs`` / ``outputs`` of every expr in ``self._exprs`` are
          rewritten to the replacements.

    Raises:
        AssertionError: if a key of ``repl_dict`` is neither an input nor an
            output of this graph.
    """
    for node, repl_node in repl_dict.items():
        assert node in self._inputs or node in self._outputs
        for user in node.users:
            if user not in repl_node.users:
                repl_node.users.append(user)

    for idx, inp in enumerate(self._inputs):
        if inp in repl_dict:
            self._inputs[idx] = repl_dict[inp]

    for idx, outp in enumerate(self._outputs):
        if outp in repl_dict:
            repl_node = repl_dict[outp]
            # Take the expr from the output being replaced. The previous code
            # read ``node.expr``, i.e. the variable leaked from the first loop
            # (always the last key of ``repl_dict``), which is only correct
            # when the graph has a single output.
            repl_node.expr = outp.expr
            self._outputs[idx] = repl_node

    for expr in self._exprs:
        for idx, inp in enumerate(expr.inputs):
            if inp in repl_dict:
                expr.inputs[idx] = repl_dict[inp]

        for idx, outp in enumerate(expr.outputs):
            if outp in repl_dict:
                expr.outputs[idx] = repl_dict[outp]

def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
@@ -117,6 +146,7 @@ class InternalGraph:
expr = node.expr
if expr not in ret:
ret.append(expr)

for i in expr.inputs:
if i not in queue:
queue.append(i)
@@ -287,10 +317,7 @@ def _wrapped_function(orig_func):

call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*args, **kwargs)
@@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
nodes = None

__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"build",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
]

def __init__(self, mod, is_top_module=False):
@@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()

def build(self):
if self._is_builtin:
node = NodeMixin.get(self)
node.module_type = type(self._mod)
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
return self._mod
else:
node = NodeMixin.get(self)
traced_module = TracedModule(node)
traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
)
for _, g in self._argdef_graph_map.items():
g.compile()
# for node in self.nodes:
# node._owner = weakref.ref(traced_module)

for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(traced_module, k, v)
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module

def _record_wrapped_nodes(self, node):
    """Add ``node`` to ``self.nodes``, the collection of nodes bound to this builder."""
    bound_nodes = self.nodes
    bound_nodes.add(node)

def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph
@@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin):
if self._is_builtin:
self._body = None
else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return x

args = [self]
@@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin):

# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
self_node.argdef_outdef_map[callnode.arg_def] = out_def
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst

def __getattr__(self, name):
@@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin):
else:
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
wrapped,
lambda: GetAttr.make(
@@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped),
),
)
"""
else:
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped


@@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter):

class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
`TracedModule` is the Module created by tracing a normal module. It owns a map from argdef to graph (InternalGraph). The forward method of `TracedModule` looks up a graph in `argdef_graph_map` according to the argdef of the input args/kwargs and interprets it.
"""

m_node = None # type: ModuleNode
# m_node = None # type: ModuleNode
argdef_graph_map = None
argdef_outdef_map = None

def __init__(self, node):
def __init__(self, argdef_graph_map, argdef_outdef_map):
super(TracedModule, self).__init__()
self.m_node = node
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map

def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
assert treedef in self.argdef_graph_map
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.m_node.argdef_outdef_map[treedef]
outputs = self.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs

@property
def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]
self._update_modulenode_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]

def _update_modulenode_ref(self):
    """Refresh the ``_owner`` weak references of the module nodes in this module's graphs.

    For every graph in ``argdef_graph_map`` the first input node is bound to
    ``self``. Each ``GetAttr`` expr whose first output is a ``ModuleNode`` is
    then resolved with ``getattr`` on the object owning its first input, and
    that output node is bound to the result. Attributes that are themselves
    ``TracedModule`` instances are refreshed recursively.
    """
    for traced_graph in self.argdef_graph_map.values():
        self_node = traced_graph._inputs[0]
        self_node._owner = weakref.ref(self)
        # Maps each module node seen so far to the live object it stands for.
        owners = {self_node: self}
        for expr in traced_graph._exprs:
            if not isinstance(expr, GetAttr):
                continue
            out_node = expr.outputs[0]
            if not isinstance(out_node, ModuleNode):
                continue
            child = getattr(owners[expr.inputs[0]], expr.name)
            out_node._owner = weakref.ref(child)
            owners[out_node] = child
            if isinstance(child, TracedModule):
                child._update_modulenode_ref()

@property
def exprs(self):
@@ -561,39 +637,49 @@ class TracedModule(Module):
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
return [const, call]
if call is not None:
graph = copy.deepcopy(graph)
exprs = []
node2obj = {}
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
# replace the inputs and outputs of the submodule's graph with the call's nodes
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)

if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
const.outputs = expr.outputs
const.outputs[0].expr = const
exprs.append(const)
elif isinstance(expr.outputs[0], ModuleNode):
node2obj[expr.outputs[0]] = getattr(
node2obj[expr.inputs[0]], expr.name
)

elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
(obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
expr_graph = (
obj.argdef_graph_map[expr.arg_def]
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
exprs.append(expr)
else:
exprs.append(expr)
else:


+ 1
- 0
imperative/python/test/integration/test_converge.py View File

@@ -9,6 +9,7 @@
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad


+ 1
- 0
imperative/python/test/integration/test_converge_with_gradient_clip.py View File

@@ -9,6 +9,7 @@
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad


Loading…
Cancel
Save