From f88bd3ae33b21bcbd72aae7b5f15fd432f4c70ad Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 26 Jul 2021 14:07:06 +0800 Subject: [PATCH] refactor(traced_module): let TracedModule own argdef_graph_map GitOrigin-RevId: 80d685b9a395c7ee2d84742aad0b0f507efec7dd --- .../megengine/experimental/traced_module/expr.py | 17 +- .../experimental/traced_module/module_tracer.py | 3 + .../megengine/experimental/traced_module/node.py | 32 +++- .../experimental/traced_module/traced_module.py | 190 +++++++++++++++------ .../python/test/integration/test_converge.py | 1 + .../test_converge_with_gradient_clip.py | 1 + 6 files changed, 177 insertions(+), 67 deletions(-) diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index 1cb592d4..27cd2cf2 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -9,6 +9,7 @@ import builtins import collections +import copy import inspect from typing import Callable, List @@ -46,7 +47,7 @@ class Expr: idx = len(self.inputs) + len(self.const_val) self.const_val.append((idx, val)) - def add_outputs(self, outputs, check_inplace=True): + def add_outputs(self, outputs): self.outputs = [] if outputs is not None: if not isinstance(outputs, collections.Sequence): @@ -54,10 +55,7 @@ class Expr: for i in outputs: assert isinstance(i, RawTensor) - node = NodeMixin.get(i, None) if check_inplace else None - self.outputs.append( - node if node else NodeMixin.get_wrapped_type(i)(self) - ) + self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) for i, node in zip(outputs, self.outputs,): NodeMixin.wrap_safe(i, node) @@ -165,9 +163,12 @@ class CallMethod(Expr): def graph(self): if isinstance(self.inputs[0], ModuleNode): m_node = self.inputs[0] - if m_node.argdef_graph_map: - assert self.arg_def in m_node.argdef_graph_map - return m_node.argdef_graph_map[self.arg_def] + if ( + hasattr(m_node.owner, "argdef_graph_map") + and m_node.owner.argdef_graph_map + ): + assert self.arg_def in m_node.owner.argdef_graph_map + return m_node.owner.argdef_graph_map[self.arg_def] return None def interpret(self, *inputs): diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py index 6bdd65f9..221310ae 100644 --- a/imperative/python/megengine/experimental/traced_module/module_tracer.py +++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py @@ -184,6 +184,9 @@ class Patcher: if id(i) not in self.visited_frames_ids: self.patch_function(i, j, self.wrap_fn) + for m in module_tracer._opaque_types: + self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {})) + def patch_function(self, frame_dict, fn, wrap_fn): patched_fn = PatchedFn(frame_dict, fn) self.patched_fn_ids.add(id(patched_fn.origin_fn)) diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index bd1fc4c9..c6e605df 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -6,6 +6,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import abc +import weakref from typing import Any, Dict, List, Tuple, Type import numpy @@ -58,15 +60,10 @@ class ModuleNode(Node): """ module_type = Module # type: Type[Module] - attr_type_map = None # type: Dict[str, Type[Any]] - argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] - argdef_outdef_map = None # type: Dict[Treedef, Treedef] + _owner = None # type: weakref.ReferenceType def __init__(self, expr: "Expr", name: str = None): super().__init__(expr, name) - self.attr_type_map = {} - self.argdef_graph_map = {} - self.argdef_outdef_map = {} def __repr__(self): if self._name is None: @@ -74,6 +71,15 @@ class ModuleNode(Node): else: return "%{}({})".format(self._name, self.module_type.__name__) + def __getstate__(self): + d = self.__dict__ + d.pop("_owner", None) + return d + + @property + def owner(self): + return self._owner() + class TensorNode(Node): """ @@ -90,9 +96,14 @@ class TensorNode(Node): return "%{}(Tensor)".format(self._name) -class NodeMixin: +class NodeMixin(abc.ABC): __node = None + @abc.abstractmethod + def _record_wrapped_nodes(self, node): + # record the nodes which had been bound to this NodeMixin + pass + @classmethod def wrap(cls, value, node): if isinstance(value, (NodeMixin, RawTensor)): @@ -102,15 +113,20 @@ class NodeMixin: node.shape = ( value._tuple_shape if isinstance(value, Tensor) else value.shape ) + if isinstance(value, NodeMixin): + value._record_wrapped_nodes(node) setattr(value, "_NodeMixin__node", node) else: assert callable(node) n = node() + assert isinstance(n, Node) if isinstance(value, RawTensor): n.dtype = value.dtype n.shape = ( value._tuple_shape if isinstance(value, Tensor) else value.shape ) + if isinstance(value, NodeMixin): + value._record_wrapped_nodes(n) setattr(value, "_NodeMixin__node", n) @classmethod @@ -122,6 +138,8 @@ class NodeMixin: value._tuple_shape if isinstance(value, Tensor) else value.shape ) setattr(value, "_NodeMixin__node", node) + if isinstance(value, NodeMixin): + value._record_wrapped_nodes(node) @classmethod def get(cls, value, *default): diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 38869012..be609395 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -9,6 +9,7 @@ import collections import copy import functools +import weakref from inspect import getmembers, isclass, ismethod from typing import Callable, Dict, Iterable, List, Sequence, Type @@ -51,7 +52,9 @@ def _leaf_type(node): def _is_leaf(node): - assert isinstance(node, RawTensor), type(node) + assert isinstance(node, RawTensor), "doesn't support {} in return values".format( + type(node) + ) return isinstance(node, RawTensor) @@ -107,6 +110,32 @@ class InternalGraph: def add_output(self, o): self._outputs.append(o) + def _replace_inputs_outputs(self, repl_dict): + + for node, repl_node in repl_dict.items(): + assert node in self._inputs or node in self._outputs + for i in node.users: + if i not in repl_node.users: + repl_node.users.append(i) + + for idx, i in enumerate(self._inputs): + if i in repl_dict: + self._inputs[idx] = repl_dict[i] + for idx, o in enumerate(self._outputs): + if o in repl_dict: + self._outputs[idx] = repl_dict[o] + self._outputs[idx].expr = node.expr + + for expr in self._exprs: + + for idx, i in enumerate(expr.inputs): + if i in repl_dict: + expr.inputs[idx] = repl_dict[i] + + for idx, o in enumerate(expr.outputs): + if o in repl_dict: + expr.outputs[idx] = repl_dict[o] + def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: if not isinstance(nodes, Sequence): nodes = (nodes,) @@ -117,6 +146,7 @@ class InternalGraph: expr = node.expr if expr not in ret: ret.append(expr) + for i in expr.inputs: if i not in queue: queue.append(i) @@ -287,10 +317,7 @@ def _wrapped_function(orig_func): call_node.arg_def = tree_def outputs = orig_func(*args, **kwargs) - if meth_name == "__new__": - call_node.add_outputs(outputs, False) - else: - call_node.add_outputs(outputs) + call_node.add_outputs(outputs) set_module_tracing() return outputs return orig_func(*args, **kwargs) @@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin): _mod = None # type: Module _body = None # type: InternalGraph _is_builtin = None # type: bool + _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] + _argdef_outdef_map = None # type: Dict[Treedef, Treedef] + nodes = None + __builder_attributes__ = [ "_mod", "_body", "_NodeMixin__node", "_is_builtin", "build", + "_argdef_graph_map", + "_argdef_outdef_map", + "nodes", ] def __init__(self, mod, is_top_module=False): @@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin): self._mod = mod self._body = None self._is_builtin = module_tracer.is_builtin(mod) + self._argdef_graph_map = {} + self._argdef_outdef_map = {} + self.nodes = set() def build(self): if self._is_builtin: - node = NodeMixin.get(self) - node.module_type = type(self._mod) + for node in self.nodes: + node.module_type = type(self._mod) + # node._owner = weakref.ref(self._mod) return self._mod else: - node = NodeMixin.get(self) - traced_module = TracedModule(node) + traced_module = TracedModule( + self._argdef_graph_map, self._argdef_outdef_map + ) + for _, g in self._argdef_graph_map.items(): + g.compile() + # for node in self.nodes: + # node._owner = weakref.ref(traced_module) + for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: if isinstance(v, TracedModuleBuilder): v = v.build() setattr(traced_module, k, v) - traced_module.m_node.attr_type_map[k] = type(v) + return traced_module + def _record_wrapped_nodes(self, node): + self.nodes.add(node) + def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) # prepare args and kwargs for inner graph @@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin): if self._is_builtin: self._body = None else: + self_node = None + if self._body: + self_node = self._body.inputs[0] self._body = InternalGraph() active_module_tracer().push_scope(self._body) # rebind self to new input node orig_self = NodeMixin.get(self) - NodeMixin.wrap_safe( - self, Input.make("self", NodeMixin.get_wrapped_type(self)) - ) + if self_node: + NodeMixin.wrap_safe(self, self_node) + active_module_tracer().current_scope().add_input(self_node) + else: + NodeMixin.wrap_safe( + self, + self_node + if self_node + else Input.make("self", NodeMixin.get_wrapped_type(self)), + ) origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph def wrap(x): - NodeMixin.wrap( - x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)), - ) + if isinstance(x, (RawTensor, NodeMixin)): + NodeMixin.wrap( + x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)), + ) return x args = [self] @@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin): # rebind output to outer graph callnode.add_outputs(outputs) - self_node = NodeMixin.get(self) - self_node.argdef_graph_map[callnode.arg_def] = self._body - self_node.argdef_outdef_map[callnode.arg_def] = out_def + self._argdef_graph_map[callnode.arg_def] = self._body + self._argdef_outdef_map[callnode.arg_def] = out_def return rst def __getattr__(self, name): @@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin): else: wrapped = super().__getattribute__(name) if name in self._mod.__dict__: - if not NodeMixin.get(wrapped, None): - assert not self._is_builtin + assert not self._is_builtin + if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( wrapped, lambda: GetAttr.make( @@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin): type=NodeMixin.get_wrapped_type(wrapped), ), ) + """ else: node = NodeMixin.get(wrapped) - expr = GetAttr.make( - NodeMixin.get(self), - name, - type=NodeMixin.get_wrapped_type(wrapped), - ).expr - expr.outputs[0] = node + expr = node.expr + assert isinstance(expr, GetAttr) + if expr not in active_module_tracer().current_scope()._exprs: + active_module_tracer().current_scope().insert(expr) + """ + return wrapped @@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter): class TracedModule(Module): """ - `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be - interpreted by CallMethod Expr. + `TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it. """ - m_node = None # type: ModuleNode + # m_node = None # type: ModuleNode + argdef_graph_map = None + argdef_outdef_map = None - def __init__(self, node): + def __init__(self, argdef_graph_map, argdef_outdef_map): super(TracedModule, self).__init__() - self.m_node = node + self.argdef_graph_map = argdef_graph_map + self.argdef_outdef_map = argdef_outdef_map def forward(self, *args, **kwargs): inputs, treedef = tree_flatten( ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf ) - assert treedef in self.m_node.argdef_graph_map + assert treedef in self.argdef_graph_map inputs = filter( lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs ) # allow TracedModuleBuilder for retrace. - outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs) - out_def = self.m_node.argdef_outdef_map[treedef] + outputs = self.argdef_graph_map[treedef].interpret(*inputs) + out_def = self.argdef_outdef_map[treedef] outputs = out_def.unflatten(outputs) return outputs @property def graph(self): - assert len(self.m_node.argdef_graph_map) == 1 - return list(self.m_node.argdef_graph_map.values())[0] + self._update_modulenode_ref() + assert len(self.argdef_graph_map) == 1 + return list(self.argdef_graph_map.values())[0] + + def _update_modulenode_ref(self): + for _, graph in self.argdef_graph_map.items(): + graph._inputs[0]._owner = weakref.ref(self) + node2obj = {} + node2obj[graph._inputs[0]] = self + for expr in graph._exprs: + if isinstance(expr, GetAttr) and isinstance( + expr.outputs[0], ModuleNode + ): + obj = getattr(node2obj[expr.inputs[0]], expr.name) + expr.outputs[0]._owner = weakref.ref(obj) + node2obj[expr.outputs[0]] = obj + if isinstance(obj, TracedModule): + obj._update_modulenode_ref() @property def exprs(self): @@ -561,39 +637,49 @@ class TracedModule(Module): const.outputs[0] = call.inputs[0] const.outputs[0].expr = const return [const, call] + if call is not None: + graph = copy.deepcopy(graph) exprs = [] + node2obj = {} + node2obj[graph._inputs[0]] = module + if call: + node2obj[call.inputs[0]] = module for expr in graph._exprs: - # replace inputs for submodule's expr - for idx, inp in enumerate(expr.inputs): - if call and inp in graph._inputs: - inp_idx = graph._inputs.index(inp) - expr.inputs[idx] = call.inputs[inp_idx] - call.inputs[inp_idx].users.append(expr) - # replace outputs for submodule's expr - for idx, outp in enumerate(expr.outputs): - if call and outp in graph._outputs: - oup_idx = graph._outputs.index(outp) - expr.outputs[idx] = call.outputs[oup_idx] - call.outputs[oup_idx].expr = expr + # replace inputs for submodule's exprx + if call: + repl_dict = dict( + zip(graph._inputs + graph._outputs, call.inputs + call.outputs) + ) + graph._replace_inputs_outputs(repl_dict) if isinstance(expr, GetAttr): # replace GetAttr with Constant if isinstance(expr.outputs[0], TensorNode): - const = Constant(getattr(module, expr.name)) + const = Constant(getattr(node2obj[expr.inputs[0]], expr.name)) const.outputs = expr.outputs const.outputs[0].expr = const exprs.append(const) + elif isinstance(expr.outputs[0], ModuleNode): + node2obj[expr.outputs[0]] = getattr( + node2obj[expr.inputs[0]], expr.name + ) elif isinstance(expr, CallMethod): obj_node = expr.inputs[0] if isinstance(obj_node, ModuleNode): pre_expr = expr.inputs[0].expr if isinstance(pre_expr, GetAttr): - (obj,) = expr.inputs[0].expr.interpret(module) - exprs.extend(_flatten_subgraph(expr.graph, obj, expr)) + (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]]) + expr_graph = ( + obj.argdef_graph_map[expr.arg_def] + if hasattr(obj, "argdef_graph_map") + else None + ) + exprs.extend(_flatten_subgraph(expr_graph, obj, expr)) else: # module has been replaced. assert isinstance(pre_expr, Constant) + exprs.append(expr) else: exprs.append(expr) else: diff --git a/imperative/python/test/integration/test_converge.py b/imperative/python/test/integration/test_converge.py index ab32aaa9..fb080b41 100644 --- a/imperative/python/test/integration/test_converge.py +++ b/imperative/python/test/integration/test_converge.py @@ -9,6 +9,7 @@ import itertools import numpy as np +import pytest import megengine as mge import megengine.autodiff as ad diff --git a/imperative/python/test/integration/test_converge_with_gradient_clip.py b/imperative/python/test/integration/test_converge_with_gradient_clip.py index 6ec9fafe..9a0b8393 100644 --- a/imperative/python/test/integration/test_converge_with_gradient_clip.py +++ b/imperative/python/test/integration/test_converge_with_gradient_clip.py @@ -9,6 +9,7 @@ import itertools import numpy as np +import pytest import megengine as mge import megengine.autodiff as ad