Browse Source

refactor(traced_module): let TracedModule own argdef_graph_map

GitOrigin-RevId: 80d685b9a3
release-1.6
Megvii Engine Team 3 years ago
parent
commit
f88bd3ae33
6 changed files with 177 additions and 67 deletions
  1. +9
    -8
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +3
    -0
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +25
    -7
      imperative/python/megengine/experimental/traced_module/node.py
  4. +138
    -52
      imperative/python/megengine/experimental/traced_module/traced_module.py
  5. +1
    -0
      imperative/python/test/integration/test_converge.py
  6. +1
    -0
      imperative/python/test/integration/test_converge_with_gradient_clip.py

+ 9
- 8
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -9,6 +9,7 @@


import builtins import builtins
import collections import collections
import copy
import inspect import inspect
from typing import Callable, List from typing import Callable, List


@@ -46,7 +47,7 @@ class Expr:
idx = len(self.inputs) + len(self.const_val) idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val)) self.const_val.append((idx, val))


def add_outputs(self, outputs, check_inplace=True):
def add_outputs(self, outputs):
self.outputs = [] self.outputs = []
if outputs is not None: if outputs is not None:
if not isinstance(outputs, collections.Sequence): if not isinstance(outputs, collections.Sequence):
@@ -54,10 +55,7 @@ class Expr:


for i in outputs: for i in outputs:
assert isinstance(i, RawTensor) assert isinstance(i, RawTensor)
node = NodeMixin.get(i, None) if check_inplace else None
self.outputs.append(
node if node else NodeMixin.get_wrapped_type(i)(self)
)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))


for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node) NodeMixin.wrap_safe(i, node)
@@ -165,9 +163,12 @@ class CallMethod(Expr):
def graph(self): def graph(self):
if isinstance(self.inputs[0], ModuleNode): if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0] m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
if (
hasattr(m_node.owner, "argdef_graph_map")
and m_node.owner.argdef_graph_map
):
assert self.arg_def in m_node.owner.argdef_graph_map
return m_node.owner.argdef_graph_map[self.arg_def]
return None return None


def interpret(self, *inputs): def interpret(self, *inputs):


+ 3
- 0
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -184,6 +184,9 @@ class Patcher:
if id(i) not in self.visited_frames_ids: if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn) self.patch_function(i, j, self.wrap_fn)


for m in module_tracer._opaque_types:
self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {}))

def patch_function(self, frame_dict, fn, wrap_fn): def patch_function(self, frame_dict, fn, wrap_fn):
patched_fn = PatchedFn(frame_dict, fn) patched_fn = PatchedFn(frame_dict, fn)
self.patched_fn_ids.add(id(patched_fn.origin_fn)) self.patched_fn_ids.add(id(patched_fn.origin_fn))


+ 25
- 7
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import weakref
from typing import Any, Dict, List, Tuple, Type from typing import Any, Dict, List, Tuple, Type


import numpy import numpy
@@ -58,15 +60,10 @@ class ModuleNode(Node):
""" """


module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
attr_type_map = None # type: Dict[str, Type[Any]]
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map = None # type: Dict[Treedef, Treedef]
_owner = None # type: weakref.ReferenceType


def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name) super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
self.argdef_outdef_map = {}


def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
@@ -74,6 +71,15 @@ class ModuleNode(Node):
else: else:
return "%{}({})".format(self._name, self.module_type.__name__) return "%{}({})".format(self._name, self.module_type.__name__)


def __getstate__(self):
    """Return this node's state for pickling, without the ``_owner`` entry.

    The state is built from a shallow copy of ``__dict__`` so that removing
    ``_owner`` from it does not also remove the attribute from the live
    node. (``_owner`` holds a ``weakref.ref``; presumably it is dropped
    because weak references cannot be pickled -- confirm with callers.)

    Returns:
        dict: a copy of the instance attributes with ``_owner`` removed.
    """
    state = self.__dict__.copy()
    state.pop("_owner", None)
    return state

@property
def owner(self):
    """Return the object this node's ``_owner`` weak reference points to.

    ``_owner`` defaults to ``None`` on the class and is only filled in once
    a reference has been attached, so an unset reference yields ``None``
    instead of raising ``TypeError`` from calling ``None``.

    Returns:
        The referenced object, or ``None`` when no reference is set (a dead
        weak reference also evaluates to ``None``).
    """
    ref = self._owner
    if ref is None:
        return None
    return ref()



class TensorNode(Node): class TensorNode(Node):
""" """
@@ -90,9 +96,14 @@ class TensorNode(Node):
return "%{}(Tensor)".format(self._name) return "%{}(Tensor)".format(self._name)




class NodeMixin:
class NodeMixin(abc.ABC):
__node = None __node = None


# Abstract hook: subclasses must implement it. The wrap helpers in this class
# call it with the node that is being bound to a NodeMixin value.
@abc.abstractmethod
def _record_wrapped_nodes(self, node):
# record the nodes which had been bound to this NodeMixin
pass

@classmethod @classmethod
def wrap(cls, value, node): def wrap(cls, value, node):
if isinstance(value, (NodeMixin, RawTensor)): if isinstance(value, (NodeMixin, RawTensor)):
@@ -102,15 +113,20 @@ class NodeMixin:
node.shape = ( node.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape value._tuple_shape if isinstance(value, Tensor) else value.shape
) )
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
setattr(value, "_NodeMixin__node", node) setattr(value, "_NodeMixin__node", node)
else: else:
assert callable(node) assert callable(node)
n = node() n = node()
assert isinstance(n, Node)
if isinstance(value, RawTensor): if isinstance(value, RawTensor):
n.dtype = value.dtype n.dtype = value.dtype
n.shape = ( n.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape value._tuple_shape if isinstance(value, Tensor) else value.shape
) )
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(n)
setattr(value, "_NodeMixin__node", n) setattr(value, "_NodeMixin__node", n)


@classmethod @classmethod
@@ -122,6 +138,8 @@ class NodeMixin:
value._tuple_shape if isinstance(value, Tensor) else value.shape value._tuple_shape if isinstance(value, Tensor) else value.shape
) )
setattr(value, "_NodeMixin__node", node) setattr(value, "_NodeMixin__node", node)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)


@classmethod @classmethod
def get(cls, value, *default): def get(cls, value, *default):


+ 138
- 52
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,6 +9,7 @@
import collections import collections
import copy import copy
import functools import functools
import weakref
from inspect import getmembers, isclass, ismethod from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type from typing import Callable, Dict, Iterable, List, Sequence, Type


@@ -51,7 +52,9 @@ def _leaf_type(node):




def _is_leaf(node): def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
)
return isinstance(node, RawTensor) return isinstance(node, RawTensor)




@@ -107,6 +110,32 @@ class InternalGraph:
def add_output(self, o): def add_output(self, o):
self._outputs.append(o) self._outputs.append(o)


def _replace_inputs_outputs(self, repl_dict):
    """Substitute input/output nodes of this graph according to ``repl_dict``.

    Args:
        repl_dict: mapping from a node that is currently one of this graph's
            inputs or outputs to the node that should take its place.

    For every pair, the users of the old node are appended to the
    replacement's ``users`` (without duplicates). The replacement is then
    swapped into ``_inputs`` / ``_outputs`` and into the ``inputs`` /
    ``outputs`` lists of every expr in ``_exprs``. A replacement output
    additionally takes over the ``expr`` of the output it replaces.

    Raises:
        AssertionError: if a key of ``repl_dict`` is neither an input nor an
            output of this graph.
    """
    for node, repl_node in repl_dict.items():
        assert node in self._inputs or node in self._outputs
        for user in node.users:
            if user not in repl_node.users:
                repl_node.users.append(user)

    for idx, inp in enumerate(self._inputs):
        if inp in repl_dict:
            self._inputs[idx] = repl_dict[inp]

    for idx, out in enumerate(self._outputs):
        if out in repl_dict:
            self._outputs[idx] = repl_dict[out]
            # Use the expr of the output being replaced. Reading it from the
            # loop variable left over by the first loop would give every
            # output the expr of the last key in ``repl_dict``.
            self._outputs[idx].expr = out.expr

    for expr in self._exprs:
        for idx, inp in enumerate(expr.inputs):
            if inp in repl_dict:
                expr.inputs[idx] = repl_dict[inp]

        for idx, out in enumerate(expr.outputs):
            if out in repl_dict:
                expr.outputs[idx] = repl_dict[out]

def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
nodes = (nodes,) nodes = (nodes,)
@@ -117,6 +146,7 @@ class InternalGraph:
expr = node.expr expr = node.expr
if expr not in ret: if expr not in ret:
ret.append(expr) ret.append(expr)

for i in expr.inputs: for i in expr.inputs:
if i not in queue: if i not in queue:
queue.append(i) queue.append(i)
@@ -287,10 +317,7 @@ def _wrapped_function(orig_func):


call_node.arg_def = tree_def call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs) outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
call_node.add_outputs(outputs)
set_module_tracing() set_module_tracing()
return outputs return outputs
return orig_func(*args, **kwargs) return orig_func(*args, **kwargs)
@@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
_body = None # type: InternalGraph _body = None # type: InternalGraph
_is_builtin = None # type: bool _is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
nodes = None

__builder_attributes__ = [ __builder_attributes__ = [
"_mod", "_mod",
"_body", "_body",
"_NodeMixin__node", "_NodeMixin__node",
"_is_builtin", "_is_builtin",
"build", "build",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
] ]


def __init__(self, mod, is_top_module=False): def __init__(self, mod, is_top_module=False):
@@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod self._mod = mod
self._body = None self._body = None
self._is_builtin = module_tracer.is_builtin(mod) self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()


def build(self): def build(self):
if self._is_builtin: if self._is_builtin:
node = NodeMixin.get(self)
node.module_type = type(self._mod)
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
return self._mod return self._mod
else: else:
node = NodeMixin.get(self)
traced_module = TracedModule(node)
traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
)
for _, g in self._argdef_graph_map.items():
g.compile()
# for node in self.nodes:
# node._owner = weakref.ref(traced_module)

for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
if isinstance(v, TracedModuleBuilder): if isinstance(v, TracedModuleBuilder):
v = v.build() v = v.build()
setattr(traced_module, k, v) setattr(traced_module, k, v)
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module return traced_module


# Remember `node` in this builder's `nodes` set; build() later iterates that
# set to stamp `module_type` on each recorded node of a builtin module.
def _record_wrapped_nodes(self, node):
self.nodes.add(node)

def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module) assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
@@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin):
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
else: else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
self._body = InternalGraph() self._body = InternalGraph()
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
orig_self = NodeMixin.get(self) orig_self = NodeMixin.get(self)
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): def wrap(x):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return x return x


args = [self] args = [self]
@@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin):


# rebind output to outer graph # rebind output to outer graph
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
self_node.argdef_outdef_map[callnode.arg_def] = out_def
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst return rst


def __getattr__(self, name): def __getattr__(self, name):
@@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin):
else: else:
wrapped = super().__getattribute__(name) wrapped = super().__getattribute__(name)
if name in self._mod.__dict__: if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap( NodeMixin.wrap(
wrapped, wrapped,
lambda: GetAttr.make( lambda: GetAttr.make(
@@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped), type=NodeMixin.get_wrapped_type(wrapped),
), ),
) )
"""
else: else:
node = NodeMixin.get(wrapped) node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped return wrapped




@@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter):


class TracedModule(Module): class TracedModule(Module):
""" """
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.
""" """


m_node = None # type: ModuleNode
# m_node = None # type: ModuleNode
argdef_graph_map = None
argdef_outdef_map = None


def __init__(self, node):
def __init__(self, argdef_graph_map, argdef_outdef_map):
super(TracedModule, self).__init__() super(TracedModule, self).__init__()
self.m_node = node
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map


def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten( inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
) )
assert treedef in self.m_node.argdef_graph_map
assert treedef in self.argdef_graph_map
inputs = filter( inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace. ) # allow TracedModuleBuilder for retrace.
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.m_node.argdef_outdef_map[treedef]
outputs = self.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs) outputs = out_def.unflatten(outputs)
return outputs return outputs


@property @property
def graph(self): def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]
self._update_modulenode_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]

# Refresh the `_owner` weak references of the ModuleNodes found in every graph
# of `argdef_graph_map`, so that each node points at the module object it
# currently resolves to. Recurses into attributes that are TracedModules.
def _update_modulenode_ref(self):
for _, graph in self.argdef_graph_map.items():
# the first graph input is bound to this TracedModule itself
graph._inputs[0]._owner = weakref.ref(self)
# maps a node to the object it resolves to; seeded with the first input
node2obj = {}
node2obj[graph._inputs[0]] = self
for expr in graph._exprs:
# only GetAttr exprs whose first output is a ModuleNode are handled
if isinstance(expr, GetAttr) and isinstance(
expr.outputs[0], ModuleNode
):
# look the attribute up on the object recorded for the expr's first input;
# this relies on that input having been recorded by an earlier expr
obj = getattr(node2obj[expr.inputs[0]], expr.name)
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
# nested TracedModules refresh their own graphs the same way
if isinstance(obj, TracedModule):
obj._update_modulenode_ref()


@property @property
def exprs(self): def exprs(self):
@@ -561,39 +637,49 @@ class TracedModule(Module):
const.outputs[0] = call.inputs[0] const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const const.outputs[0].expr = const
return [const, call] return [const, call]
if call is not None:
graph = copy.deepcopy(graph)
exprs = [] exprs = []
node2obj = {}
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
for expr in graph._exprs: for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
# replace inputs and outputs for submodule's exprs
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)


if isinstance(expr, GetAttr): if isinstance(expr, GetAttr):
# replace GetAttr with Constant # replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode): if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
const.outputs = expr.outputs const.outputs = expr.outputs
const.outputs[0].expr = const const.outputs[0].expr = const
exprs.append(const) exprs.append(const)
elif isinstance(expr.outputs[0], ModuleNode):
node2obj[expr.outputs[0]] = getattr(
node2obj[expr.inputs[0]], expr.name
)


elif isinstance(expr, CallMethod): elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0] obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode): if isinstance(obj_node, ModuleNode):
pre_expr = expr.inputs[0].expr pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr): if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
(obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
expr_graph = (
obj.argdef_graph_map[expr.arg_def]
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
else: else:
# module has been replaced. # module has been replaced.
assert isinstance(pre_expr, Constant) assert isinstance(pre_expr, Constant)
exprs.append(expr)
else: else:
exprs.append(expr) exprs.append(expr)
else: else:


+ 1
- 0
imperative/python/test/integration/test_converge.py View File

@@ -9,6 +9,7 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad


+ 1
- 0
imperative/python/test/integration/test_converge_with_gradient_clip.py View File

@@ -9,6 +9,7 @@
import itertools import itertools


import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad


Loading…
Cancel
Save