From 4bb253695b34cb7b8895ac95b67029603c43dfca Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Jul 2021 14:37:33 +0800 Subject: [PATCH] feat(traced_module): let CallFunction own graph GitOrigin-RevId: 66cdbca7e54df07576a984c3fd48d3bcafb678f1 --- .../megengine/experimental/traced_module/expr.py | 15 ++- .../megengine/experimental/traced_module/node.py | 8 +- .../megengine/experimental/traced_module/pytree.py | 38 ++++++- .../experimental/traced_module/traced_module.py | 119 +++++++++++---------- 4 files changed, 115 insertions(+), 65 deletions(-) diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index c4dce926..361e19ec 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module from ...core.ops.special import Const from ...module import Module from ...tensor import Tensor -from .module_tracer import active_module_tracer +from .module_tracer import active_module_tracer, module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode from .pytree import TreeDef @@ -148,6 +148,15 @@ class CallMethod(Expr): active_module_tracer().current_scope().insert(expr) return expr + @property + def graph(self): + if isinstance(self.inputs[0], ModuleNode): + m_node = self.inputs[0] + if m_node.argdef_graph_map: + assert self.arg_def in m_node.argdef_graph_map + return m_node.argdef_graph_map[self.arg_def] + return None + def interpret(self, *inputs): args, kwargs = self.unflatten_args(inputs) obj = args[0] @@ -252,7 +261,9 @@ class Constant(Expr): _constant_cache = {} def __init__(self, c): - # TODO: type check, since not all types should be captured as constant + assert isinstance(c, (RawTensor, Module)) + if isinstance(c, Module): + assert module_tracer.is_builtin(c) self.value = c self.inputs = [] node_cls = NodeMixin.get_wrapped_type(c) diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index 066fefe6..9a7436e9 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -57,9 +57,13 @@ class ModuleNode(Node): """ module_type = Module # type: Type[Module] - graph = None attr_type_map = None # type: Dict[str, Type[Any]] - arg_def = None # type: TreeDef + argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] + + def __init__(self, expr: "Expr", name: str = None): + super().__init__(expr, name) + self.attr_type_map = {} + self.argdef_graph_map = {} def __repr__(self): if self._name is None: diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py index d3cb9fed..f6c5d7ea 100644 --- a/imperative/python/megengine/experimental/traced_module/pytree.py +++ b/imperative/python/megengine/experimental/traced_module/pytree.py @@ -25,7 +25,7 @@ def _dict_flatten(inp): for key, value in sorted(inp.items()): results.append(value) aux_data.append(key) - return results, aux_data + return results, tuple(aux_data) def _dict_unflatten(inps, aux_data): @@ -43,16 +43,23 @@ register_supported_type( def tree_flatten( - values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True + values, + leaf_type: Callable = lambda x: type(x), + is_leaf: Callable = lambda _: True, + is_const_leaf: Callable = lambda _: False, ): if type(values) not in SUPPORTED_TYPE: assert is_leaf(values) - return [values,], LeafDef(leaf_type(values)) + node = LeafDef(leaf_type(values)) + if is_const_leaf(values): + node.const_val = values + return [values,], node + rst = [] children_defs = [] children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values) for v in children_values: - v_list, treedef = tree_flatten(v, leaf_type) + v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf) rst.extend(v_list) children_defs.append(treedef) @@ -75,6 +82,18 @@ class TreeDef: start += ch.num_leaves return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data) + def __hash__(self): + return hash( + tuple( + [ + self.type, + self.aux_data, + self.num_leaves, + tuple([hash(x) for x in self.children_defs]), + ] + ) + ) + def __eq__(self, other): return ( self.type == other.type @@ -93,11 +112,20 @@ class LeafDef(TreeDef): type = (type,) super().__init__(type, None, []) self.num_leaves = 1 + self.const_val = None def unflatten(self, leaves): assert len(leaves) == 1 assert isinstance(leaves[0], self.type), self.type return leaves[0] + def __eq__(self, other): + return self.type == other.type and self.const_val == other.const_val + + def __hash__(self): + return hash(tuple([self.type, self.const_val])) + def __repr__(self): - return "Leaf({})".format(", ".join(t.__name__ for t in self.type)) + return "Leaf({}[{}])".format( + ", ".join(t.__name__ for t in self.type), self.const_val + ) diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 9b8b8e34..9c80eb85 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -42,6 +42,12 @@ def _leaf_type(node): return type(node) +def _is_const_leaf(node): + if isinstance(node, (RawTensor, NodeMixin, Module)): + return False + return True + + class InternalGraph: """ ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. @@ -72,6 +78,10 @@ class InternalGraph: def outputs(self): return self._outputs + @property + def exprs(self): + return _expr_list(self) + def add_input(self, i): self._inputs.append(i) @@ -111,7 +121,9 @@ def _wrapped_function(orig_func): def wrapped_fn(*args, **kwargs): if is_tracing_module(): unset_module_tracing() - inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) + inputs, tree_def = tree_flatten( + (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf + ) for i in inputs: if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): @@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin): _mod = None # type: Module _body = None # type: InternalGraph _is_builtin = None # type: bool - _arg_def = None # type: TreeDef __builder_attributes__ = [ "_mod", "_body", "_NodeMixin__node", "_is_builtin", - "_is_traced", - "_arg_def" "build", + "build", ] - def __init__(self, mod): + def __init__(self, mod, is_top_module=False): super(TracedModuleBuilder, self).__init__() self._mod = mod - self._body = InternalGraph() - self._is_traced = False + self._body = None self._is_builtin = module_tracer.is_builtin(mod) def build(self): @@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin): return self._mod else: node = NodeMixin.get(self) - node.graph = self._body - node.attr_type_map = {} - node.arg_def = self._arg_def traced_module = TracedModule(node) for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: @@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin): def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) - for arg in args: - assert isinstance(arg, RawTensor) - - for k, v in kwargs.items(): - assert isinstance(v, RawTensor) # prepare args and kwargs for inner graph def mark_constant(x): node = NodeMixin.get(x, None) if node is None: # capture as constant NodeMixin.wrap(x, lambda: Constant.make(x)) - inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) - if self._arg_def is None: - self._arg_def = tree_def - assert self._arg_def == tree_def + inputs, tree_def = tree_flatten( + ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf + ) for i in inputs: mark_constant(i) callnode = CallMethod.make(NodeMixin.get(self)) @@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin): callnode.arg_def = tree_def - if self._is_builtin or self._is_traced: + if self._is_builtin: unset_module_tracing() outputs = self._mod(*args, **kwargs) set_module_tracing() if self._is_builtin: self._body = None else: + self._body = InternalGraph() active_module_tracer().push_scope(self._body) # rebind self to new input node orig_self = NodeMixin.get(self) @@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin): active_module_tracer().current_scope().add_output(NodeMixin.get(i)) NodeMixin.wrap_safe(self, orig_self) - self._is_traced = True active_module_tracer().pop_scope() # rebind output to outer graph callnode.add_outputs(outputs) + self_node = NodeMixin.get(self) + self_node.argdef_graph_map[callnode.arg_def] = self._body return outputs def __getattr__(self, name): @@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin): class _expr_list: - def __init__(self, module: "TracedModule"): - self.module = module + def __init__(self, graph: InternalGraph): + self.graph = graph def __iter__(self): - graph = self.module.m_node.graph - for expr in graph._exprs: + for expr in self.graph._exprs: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr - assert isinstance(expr.inputs[0].expr, GetAttr) - (obj,) = expr.inputs[0].expr.interpret(self.module) - if isinstance(obj, TracedModule): - yield from obj.exprs - yield expr + if expr.graph is not None: + yield from expr.graph.exprs + else: + yield expr class TracedModule(Module): """ - `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. + `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be + interpreted by CallMethod Expr. """ m_node = None # type: ModuleNode @@ -307,21 +308,24 @@ class TracedModule(Module): self.m_node = node def forward(self, *args, **kwargs): - inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) - assert treedef == self.m_node.arg_def - rst = self.m_node.graph.interpret(*inputs) - if len(rst) == 1: - rst = rst[0] - return rst + inputs, treedef = tree_flatten( + ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf + ) + assert treedef in self.m_node.argdef_graph_map + inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))] + outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs) + if len(outputs) == 1: + return outputs[0] + return outputs @property - def exprs(self): - """ - Get all ``Expr`` s recursively. + def graph(self): + assert len(self.m_node.argdef_graph_map) == 1 + return list(self.m_node.argdef_graph_map.values())[0] - :return: Iterator[Expr] - """ - return _expr_list(self) + @property + def exprs(self): + return self.graph.exprs def flatten(self): """ @@ -331,24 +335,26 @@ class TracedModule(Module): """ new_module = copy.deepcopy(self) - def _flatten_submodule(module, call=None): - if not isinstance(module, TracedModule): - call.inputs[0] = module - return (call,) - + def _flatten_subgraph(graph, module, call=None): + if graph is None: + assert not isinstance(module, TracedModule) + const = Constant(module) + modulenode = const.outputs[0] + modulenode.module_type = type(module) + call.inputs[0] = modulenode + return [const, call] exprs = [] - - graph = module.m_node.graph for expr in graph._exprs: - # replace inputs for submodule's expr for idx, inp in enumerate(expr.inputs): if call and inp in graph._inputs: - expr.inputs[idx] = call.inputs[idx] + inp_idx = graph._inputs.index(inp) + expr.inputs[idx] = call.inputs[inp_idx] # replace outputs for submodule's expr for idx, outp in enumerate(expr.outputs): if call and outp in graph._outputs: - expr.outputs[idx] = call.outputs[idx] + oup_idx = graph._outputs.index(outp) + expr.outputs[idx] = call.outputs[oup_idx] if isinstance(expr, GetAttr): # replace GetAttr with Constant @@ -356,12 +362,13 @@ class TracedModule(Module): const = Constant(getattr(module, expr.name)) const.outputs = expr.outputs exprs.append(const) + elif isinstance(expr, CallMethod): obj_node = expr.inputs[0] if isinstance(obj_node, ModuleNode): assert isinstance(expr.inputs[0].expr, GetAttr) (obj,) = expr.inputs[0].expr.interpret(module) - exprs.extend(_flatten_submodule(obj, expr)) + exprs.extend(_flatten_subgraph(expr.graph, obj, expr)) else: exprs.append(expr) else: @@ -369,7 +376,7 @@ class TracedModule(Module): return exprs - new_module.m_node.graph._exprs = _flatten_submodule(new_module) + new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) return new_module @@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: global_scope = InternalGraph() active_module_tracer().push_scope(global_scope) - builder = TracedModuleBuilder(mod) + builder = TracedModuleBuilder(mod, True) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs):