From c7e730bc12e0255aa82ed5ea37a6f1f06fdd7dc2 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 26 Jul 2021 00:02:04 +0800
Subject: [PATCH] feat(traced_module): add some functions for graph modification

GitOrigin-RevId: 09691ebd334072f822226125acb11cebdc218618
---
 .../experimental/traced_module/__init__.py         |   2 +
 .../megengine/experimental/traced_module/expr.py   |  81 ++-
 .../megengine/experimental/traced_module/node.py   |  30 +-
 .../megengine/experimental/traced_module/pytree.py |  23 +
 .../experimental/traced_module/traced_module.py    | 590 +++++++++++++++++----
 .../test/unit/traced_module/test_modification.py   |  10 +-
 6 files changed, 612 insertions(+), 124 deletions(-)

diff --git a/imperative/python/megengine/experimental/traced_module/__init__.py b/imperative/python/megengine/experimental/traced_module/__init__.py
index cad44a0c..bda9fe92 100644
--- a/imperative/python/megengine/experimental/traced_module/__init__.py
+++ b/imperative/python/megengine/experimental/traced_module/__init__.py
@@ -13,6 +13,8 @@ from .traced_module import (
     cpp_apply_module_trace,
     register_as_builtin,
     trace_module,
+    wrap,
+    wrap_tensors,
 )
 
 _register_all_builtin_module()
diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py
index 27cd2cf2..1f8ff685 100644
--- a/imperative/python/megengine/experimental/traced_module/expr.py
+++ b/imperative/python/megengine/experimental/traced_module/expr.py
@@ -11,7 +11,7 @@ import builtins
 import collections
 import copy
 import inspect
-from typing import Callable, List
+from typing import Callable, Dict, List
 
 from ...core._imperative_rt import OpDef
 from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -29,10 +29,24 @@ class Expr:
     ``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
     """
 
+    __total_id = 0
     inputs = None  # type: List[Node]
     outputs = None  # type: List[Node]
     const_val = None  # type: List[Any]
     arg_def = None  # type: TreeDef
+    out_def = None  # type: TreeDef
+    _top_graph = None  # type: weakref.ReferenceType
+
+    def __init__(self) -> None:
+        self._id = Expr.__total_id
+        Expr.__total_id += 1
+        self._disable_remove = False
+
+    def enable_remove(self):
+        self._disable_remove = False
+
+    def disable_remove(self):
+        self._disable_remove = True
 
     def add_inputs(self, vals):
         if not isinstance(vals, collections.abc.Sequence):
@@ -70,6 +84,22 @@ class Expr:
         else:
             return inputs, {}
 
+    def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
+        while repl_dict:
+            node, repl_node = repl_dict.popitem()
+            assert type(node) == type(repl_node)
+            assert node in nodes
+            index = nodes.index(node)
+            nodes[index] = repl_node
+            repl_node.users.append(self)
+            node.users.remove(self)
+
+    def replace_inputs(self, repl_dict: Dict[Node, Node]):
+        self._replace_nodes(repl_dict, self.inputs)
+
+    def replace_outputs(self, repl_dict: Dict[Node, Node]):
+        self._replace_nodes(repl_dict, self.outputs)
+
     @property
     def kwargs(self):
         _, kwargs = self.unflatten_args(self.inputs)
@@ -80,12 +110,19 @@ class Expr:
         args, _ = self.unflatten_args(self.inputs)
         return args
 
+    @property
+    def top_graph(self):
+        if self._top_graph:
+            return self._top_graph()
+        return None
+
 
# expr: None (i.e. 
fake expression which is used to mark input) class Input(Expr): name = None def __init__(self, name=None, type=None): + super().__init__() self.inputs = [] node_cls = type if type else Node self.outputs = [ @@ -100,7 +137,7 @@ class Input(Expr): return expr.outputs[0] def __repr__(self): - return "{} = Input({})".format(self.outputs[0], self.name) + return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name) @@ -108,6 +145,7 @@ class GetAttr(Expr): name = None def __init__(self, module, name, type=None): + super().__init__() assert isinstance(module, ModuleNode) self.inputs = [ module, @@ -130,14 +168,15 @@ class GetAttr(Expr): return (getattr(inputs[0], self.name),) def __repr__(self): - return '{} = GetAttr({}, "{}")'.format( - self.outputs[0], self.inputs[0], self.name + return '%{}: {} = GetAttr({}, "{}")'.format( + self._id, self.outputs[0], self.inputs[0], self.name ) # expr: outputs = inputs[0].__call__(*inputs[1:]) class CallMethod(Expr): def __init__(self, node, method="__call__"): + super().__init__() if isinstance(node, type): assert issubclass(node, Tensor) cls = Parameter if issubclass(node, Parameter) else Tensor @@ -178,6 +217,8 @@ class CallMethod(Expr): if inspect.ismethod(meth): args = args[1:] outputs = getattr(obj, self.method)(*args, **kwargs) + if self.method == "__setitem__": + outputs = obj if outputs is None: return outputs outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) @@ -186,8 +227,12 @@ class CallMethod(Expr): def __repr__(self): args = ", ".join(str(i) for i in self.args[1:]) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) - return "{} = {}.{}({})".format( - ", ".join(str(i) for i in self.outputs), + outputs = self.outputs + if self.out_def: + outputs = self.out_def.unflatten(outputs) + return "%{}: {}{}.{}({})".format( + self._id, + str(outputs) + " = " if outputs else "", self.args[0], self.method, ", ".join([args, kwargs]), @@ -199,6 +244,7 @@ class Apply(Expr): opdef = None def __init__(self, opdef): + super().__init__() assert isinstance(opdef, OpDef) self.opdef = opdef self.inputs = [] @@ -213,7 +259,8 @@ class Apply(Expr): return apply(self.opdef, *inputs) def __repr__(self): - return "{} = {}({})".format( + return "%{}: {} = {}({})".format( + self._id, ", ".join(str(i) for i in self.outputs), self.opdef, ", ".join(str(i) for i in self.inputs), @@ -241,6 +288,7 @@ class Apply(Expr): class CallFunction(Expr): def __init__(self, func): + super().__init__() assert isinstance(func, Callable) self.func = func self.const_val = [] @@ -255,16 +303,20 @@ class CallFunction(Expr): def interpret(self, *inputs): args, kwargs = self.unflatten_args(inputs) outputs = self.func(*args, **kwargs) - outputs = ( - outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) - ) + if outputs is None: + return outputs + outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) return outputs def __repr__(self): args = ", ".join(str(i) for i in self.args) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) - return "{} = {}({})".format( - ", ".join(str(i) for i in self.outputs), + outputs = self.outputs + if self.out_def: + outputs = self.out_def.unflatten(outputs) + return "%{}: {}{}({})".format( + self._id, + str(outputs) + " = " if outputs else "", self.func.__module__ + "." 
+ self.func.__name__,
             ", ".join([args, kwargs]),
         )
 
@@ -277,6 +329,7 @@ class Constant(Expr):
     _constant_cache = {}
 
     def __init__(self, c):
+        super().__init__()
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
             assert module_tracer.is_builtin(c)
@@ -299,7 +352,9 @@ class Constant(Expr):
         return (self.value,)
 
     def __repr__(self):
-        return "{} = Constant({})".format(self.outputs[0], type(self.value))
+        return "%{}: {} = Constant({})".format(
+            self._id, self.outputs[0], type(self.value)
+        )
 
     def __getstate__(self):
         state = self.__dict__.copy()
diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py
index c6e605df..44506ead 100644
--- a/imperative/python/megengine/experimental/traced_module/node.py
+++ b/imperative/python/megengine/experimental/traced_module/node.py
@@ -30,6 +30,7 @@ class Node:
     __total_id = 0
     _id = None
     _name = None
+    _top_graph = None  # type: weakref.ReferenceType
 
     def __init__(self, expr: "Expr", name: str = None):
         self.expr = expr
@@ -48,6 +49,12 @@ class Node:
         else:
             return "%{}".format(self._name)
 
+    @property
+    def top_graph(self):
+        if self._top_graph:
+            return self._top_graph()
+        return None
+
 
 class ModuleNode(Node):
     """
@@ -64,21 +71,28 @@ class ModuleNode(Node):
 
     def __init__(self, expr: "Expr", name: str = None):
         super().__init__(expr, name)
+        self.actual_mnode = []
 
     def __repr__(self):
         if self._name is None:
-            return "%{}({})".format(self._id, self.module_type.__name__)
+            return "%{}_({})".format(self._id, self.module_type.__name__)
         else:
-            return "%{}({})".format(self._name, self.module_type.__name__)
+            return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
 
     def __getstate__(self):
-        d = self.__dict__
-        d.pop("_owner", None)
-        return d
+        return {
+            "expr": self.expr,
+            "users": self.users,
+            "_id": self._id,
+            "_name": self._name,
+            "module_type": self.module_type,
+        }
 
     @property
     def owner(self):
-        return self._owner()
+        if self._owner:
+            return self._owner()
+        return None
 
 
 class TensorNode(Node):
@@ -91,9 +105,9 @@ class TensorNode(Node):
 
     def __repr__(self):
         if self._name is None:
-            return "%{}(Tensor)".format(self._id)
+            return "%{}_(Tensor)".format(self._id)
         else:
-            return "%{}(Tensor)".format(self._name)
+            return "%{}_{}(Tensor)".format(self._id, self._name)
 
 
 class NodeMixin(abc.ABC):
diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py
index 74ca1933..9ca05347 100644
--- a/imperative/python/megengine/experimental/traced_module/pytree.py
+++ b/imperative/python/megengine/experimental/traced_module/pytree.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections +from collections import OrderedDict from typing import Callable, NamedTuple import numpy as np @@ -34,10 +35,26 @@ def _dict_unflatten(inps, aux_data): return dict(zip(aux_data, inps)) +def _ordereddict_flatten(inp): + aux_data = [] + results = [] + for key, value in inp.items(): + results.append(value) + aux_data.append(key) + return results, tuple(aux_data) + + +def _ordereddict_unflatten(inps, aux_data): + return OrderedDict(zip(aux_data, inps)) + + register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) register_supported_type(dict, _dict_flatten, _dict_unflatten) register_supported_type( + collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten +) +register_supported_type( slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, aux_data: slice(x[0], x[1], x[2]), @@ -99,6 +116,12 @@ class TreeDef: ) ) + def __lt__(self, other): + return self.__hash__() < other.__hash__() + + def __gt__(self, other): + return self.__hash__() > other.__hash__() + def __eq__(self, other): return ( self.type == other.type diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index be609395..58d9d8e0 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -9,12 +9,10 @@ import collections import copy import functools +import inspect import weakref from inspect import getmembers, isclass, ismethod -from typing import Callable, Dict, Iterable, List, Sequence, Type - -import numpy as np -from numpy.lib.arraysetops import isin +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from ... import functional as F from ... 
import get_logger
 
 logger = get_logger(__name__)
 
 
 def _leaf_type(node):
-    if isinstance(node, RawTensor):
+    if isinstance(node, (RawTensor, TensorNode)):
         return (Tensor, TensorNode)
-    elif isinstance(node, (NodeMixin, Module)):
+    elif isinstance(node, (NodeMixin, Module, ModuleNode)):
         return (Module, ModuleNode, NodeMixin)
     else:
         return type(node)
@@ -64,6 +62,50 @@ def _is_const_leaf(node):
     return True
 
 
+def wrap_tensors(tensors: Tensor, nodes: TensorNode):
+    inp_tensors = copy.deepcopy(tensors)
+    inp_tensors, inp_def_v = tree_flatten(
+        inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
+    )
+    inp_nodes, inp_def_n = tree_flatten(
+        nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
+    )
+    for v, n in zip(inp_tensors, inp_nodes):
+        if isinstance(n, TensorNode) and isinstance(v, Tensor):
+            NodeMixin.wrap_safe(v, n)
+    return inp_def_v.unflatten(inp_tensors)
+
+
+class _InsertExprs:
+    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
+        self.graph = graph
+        self.global_scope = InternalGraph()
+        self.expr = expr
+        self.after = after
+
+    def __enter__(self):
+        self.use_sym_shape = set_symbolic_shape(True)
+        set_module_tracing()
+        assert active_module_tracer() is None
+        set_active_module_tracer(module_tracer(_wrapped_function))
+        active_module_tracer().patcher.__enter__()
+        active_module_tracer().push_scope(self.global_scope)
+
+    def __exit__(self, ty, va, tr):
+        set_symbolic_shape(self.use_sym_shape)
+        unset_module_tracing()
+        active_module_tracer().patcher.__exit__(ty, va, tr)
+        set_active_module_tracer(None)
+        index = len(self.graph._exprs) if self.after else 0
+        if self.expr is not None:
+            index = self.graph._exprs.index(self.expr)
+            if self.after:
+                index += 1
+        for expr in self.global_scope._exprs:
+            self.graph._exprs.insert(index, expr)
+            index += 1
+
+
 class InternalGraph:
     """
     ``InternalGraph`` is a graph consisting of ``Node``s and ``Expr``s; it is used to represent the execution procedure of a Module's forward method.
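
Note on the `_InsertExprs` machinery above: every `Expr` traced inside the `with` body is recorded into the scratch `global_scope` graph, and `__exit__` then splices those exprs into the target graph, either at the front, at the end, or immediately after a given anchor expr. A minimal sketch of that index arithmetic, with plain strings standing in for `Expr` objects (the `splice` helper is illustrative only, not part of this patch):

    def splice(exprs, new_exprs, anchor=None, after=True):
        # Same rule as _InsertExprs.__exit__: default to the end (after=True)
        # or the front (after=False); an explicit anchor overrides both,
        # with +1 when inserting after the anchor rather than before it.
        index = len(exprs) if after else 0
        if anchor is not None:
            index = exprs.index(anchor)
            if after:
                index += 1
        for e in new_exprs:
            exprs.insert(index, e)
            index += 1
        return exprs

    assert splice(["a", "b", "c"], ["x", "y"], anchor="b") == ["a", "b", "x", "y", "c"]
    assert splice(["a", "b", "c"], ["x"], after=False) == ["x", "a", "b", "c"]

This is why `insert_function` (added further down) can compute a single insert point from the last producer of its inputs and have all newly traced exprs land there in order.
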
@@ -95,14 +137,28 @@ class InternalGraph:
         return self._outputs
 
     @property
-    def exprs(self):
+    def expr_filter(self):
         return ExprFilter(_expr_iter(self))
 
-    def get_call_function(self, func: Callable = None):
-        return self.exprs.call_function(func)
+    @property
+    def node_filter(self):
+        return NodeFilter(_node_iter(self))
+
+    def get_function_by_type(self, func: Callable = None):
+        return self.expr_filter.call_function(func)
+
+    def get_method_by_type(self, method: str = None):
+        return self.expr_filter.call_method(method)
 
-    def get_call_method(self, method: str = None):
-        return self.exprs.call_method(method)
+    def get_expr_by_id(self, expr_id: List[int] = None):
+        return self.expr_filter.expr_id(expr_id)
+
+    def get_module_by_type(self, module_cls: Module):
+        assert issubclass(module_cls, Module)
+        return self.node_filter.type(module_cls, ModuleNode)
+
+    def get_node_by_id(self, node_id: List[int] = None):
+        return self.node_filter.node_id(node_id)
 
     def add_input(self, i):
         self._inputs.append(i)
@@ -124,7 +180,6 @@
         for idx, o in enumerate(self._outputs):
             if o in repl_dict:
                 self._outputs[idx] = repl_dict[o]
-                self._outputs[idx].expr = node.expr
 
         for expr in self._exprs:
 
@@ -135,83 +190,283 @@
             for idx, o in enumerate(expr.outputs):
                 if o in repl_dict:
                     expr.outputs[idx] = repl_dict[o]
+                    expr.outputs[idx].expr = expr
 
     def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
         if not isinstance(nodes, Sequence):
             nodes = (nodes,)
         ret = list()
         queue = list(nodes)
+        visited_queue = list()
         while queue:
             node = queue.pop()
+            visited_queue.append(node)
+
             expr = node.expr
+
             if expr not in ret:
                 ret.append(expr)
+
             for i in expr.inputs:
-                if i not in queue:
+                if i not in queue and i not in visited_queue:
                     queue.append(i)
         return ret
 
-    def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
-        if not isinstance(nodes, Sequence):
-            nodes = [nodes]
-        assert isinstance(func, Callable)
-        for i in nodes:
-            assert isinstance(
-                i, TensorNode
-            ), "CallFunction only accept TensorNode as inputs"
+    def reset_inputs(self, *args, **kwargs):
+        formal_mnode = self.inputs[0]
+        actual_mnodes = formal_mnode.actual_mnode
+        call_nodes = []
+        for n in actual_mnodes:
+            for c_expr in n.users:
+                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
+                    call_nodes.append((c_expr, n))
 
-        expr = CallFunction(func)
-        expr.inputs = nodes
+        module = formal_mnode.owner
+        assert module._is_top, "reset_inputs only supports the top-level graph"
+
+        inputs, tree_def = tree_flatten(
+            ((module, *args), kwargs),
+            leaf_type=_leaf_type,
+            is_const_leaf=_is_const_leaf,
+        )
 
-        for i in nodes:
-            i.users.append(expr)
+        def create_node(val: Tensor):
+            node = Input(type=TensorNode).outputs[0]
+            node.shape = val.shape
+            node.dtype = val.dtype
+            return node
 
-        idx = max(self._exprs.index(i.expr) for i in nodes) + 1
-        self._exprs.insert(idx, expr)
+        formal_node_inputs = [
+            formal_mnode,
+        ]
+
+        org_argdef = list(module.argdef_graph_map.keys())[0]
+        if call_nodes:
+            org_argdef = call_nodes[0][0].arg_def
+
+        for v in inputs[1:]:
+            assert isinstance(v, RawTensor)
+            formal_node_inputs.append(create_node(v))
+
+        actual_nodes = []
+        for e, n in call_nodes:
+            e.arg_def = tree_def
+            actual_node_inputs = [
+                n,
+            ]
+            for v in inputs[1:]:
+                actual_node_inputs.append(create_node(v))
+
+            for org_n in e.inputs:
+                org_n.users.remove(e)
+
+            e.inputs[:] = actual_node_inputs
+            e.const_val = []
+            actual_nodes.append(actual_node_inputs[1:])
+
+        self._inputs[:] = formal_node_inputs
+        module.argdef_graph_map[tree_def] = module.argdef_graph_map.pop(org_argdef)
+        module.argdef_outdef_map[tree_def] = module.argdef_outdef_map.pop(org_argdef)
+
+        # return formal_node_inputs[1:], actual_nodes
+        return formal_node_inputs[1:]
+
+    def add_input_node(self, shape, dtype="float32"):
+        formal_mnode = self.inputs[0]
+        actual_mnodes = formal_mnode.actual_mnode
+
+        module = formal_mnode.owner
+        assert module._is_top, "add_input_node only supports the top-level graph"
+
+        call_nodes = []
+        for n in actual_mnodes:
+            for c_expr in n.users:
+                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
+                    call_nodes.append(c_expr)
+
+        def create_node(is_input: bool = True):
+            if is_input:
+                node = Input(type=TensorNode).outputs[0]
+            else:
+                node = TensorNode(expr=None)
+            node.shape = shape
+            node.dtype = dtype
+            return node
 
-        fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
-        fake_out_val = func(*fake_inp_val)
+        org_argdef = list(module.argdef_graph_map.keys())[0]
 
-        def create_node(val: Tensor):
+        if call_nodes:
+            org_argdef = call_nodes[0].arg_def
+
+        args, kwargs = org_argdef.unflatten(self._inputs)
+        formal_inp_node = create_node(True)
+        inputs, tree_def = tree_flatten(
+            ((*args, formal_inp_node), kwargs),
+            leaf_type=_leaf_type,
+            is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
+        )
+        self._inputs[:] = inputs[:]
+
+        actual_inp_nodes = []
+        for e in call_nodes:
+            args, kwargs = e.unflatten_args(e.inputs)
+            args = args + (create_node(False),)
+            inputs, tree_def = tree_flatten(
+                (args, kwargs),
+                leaf_type=_leaf_type,
+                is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
+            )
+            e.inputs[:] = inputs[:]
+            e.arg_def = tree_def
+            actual_inp_nodes.append(args[-1])
+
+        module.argdef_graph_map[tree_def] = module.argdef_graph_map.pop(org_argdef)
+        module.argdef_outdef_map[tree_def] = module.argdef_outdef_map.pop(org_argdef)
+
+        # return formal_inp_node, actual_inp_nodes
+        return formal_inp_node
+
+    def reset_outputs(self, outputs):
+        outputs, out_def = tree_flatten(
+            outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
+        )
+        formal_mnode = self.inputs[0]
+
+        module = formal_mnode.owner
+        assert module._is_top, "reset_outputs only supports the top-level graph"
+
+        actual_mnodes = formal_mnode.actual_mnode
+        call_nodes = []
+        for n in actual_mnodes:
+            for c_expr in n.users:
+                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
+                    call_nodes.append(c_expr)
+
+        def create_node(val: TensorNode, expr: Expr):
             node = TensorNode(expr)
             node.shape = val.shape
             node.dtype = val.dtype
             return node
 
-        out_nodes = list(create_node(i) for i in fake_out_val)
-        expr.outputs = out_nodes
+        tree_def = list(module.argdef_graph_map.keys())[0]
+        if call_nodes:
+            tree_def = call_nodes[0].arg_def
 
-        return out_nodes
+        actual_nodes = []
+        for e in call_nodes:
+            actual_node_outputs = []
+            for v in outputs:
+                actual_node_outputs.append(create_node(v, e))
+            e.outputs[:] = actual_node_outputs
+            e.out_def = out_def
+            actual_nodes.append(actual_node_outputs)
 
-    def insert_call_method(self, target, method, args):
-        if not isinstance(args, Sequence):
-            args = [args]
-        assert isinstance(target, (TensorNode, ModuleNode))
-        assert isinstance(method, str)
-        for i in args:
-            assert isinstance(i, TensorNode)
+        self._outputs[:] = outputs
+        module.argdef_outdef_map[tree_def] = out_def
 
-        expr = CallMethod(method)
-        expr.inputs = [target, *args]
+        return actual_nodes
 
-        if isinstance(target, TensorNode):
-            fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
-            fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
-            fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
+    def add_output_node(self, node: TensorNode):
+        formal_mnode = self.inputs[0]
 
-            def create_node(val: Tensor):
-                node = TensorNode(expr)
-                node.shape = val.shape
-                node.dtype = val.dtype
-                return node
+        module = formal_mnode.owner
+        assert module._is_top, "add_output_node only supports the top-level graph"
 
-            out_nodes = list(create_node(i) for i in fake_out_val)
-            expr.outputs = out_nodes
-        else:
-            raise NotImplementedError()
+        actual_mnodes = formal_mnode.actual_mnode
+        call_nodes = []
+
+        for n in actual_mnodes:
+            for c_expr in n.users:
+                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
+                    call_nodes.append(c_expr)
+
+        def create_node(val: TensorNode, expr: Expr):
+            node = TensorNode(expr)
+            node.shape = val.shape
+            node.dtype = val.dtype
+            return node
+
+        tree_def = list(module.argdef_graph_map.keys())[0]
+        if call_nodes:
+            tree_def = call_nodes[0].arg_def
+
+        org_out_def = module.argdef_outdef_map[tree_def]
+        org_outs = org_out_def.unflatten(self._outputs)
+        outputs, out_def = tree_flatten(
+            (org_outs, node),
+            leaf_type=_leaf_type,
+            is_leaf=lambda x: isinstance(x, TensorNode),
+        )
+        self._outputs[:] = outputs
+
+        actual_out_nodes = []
+        for e in call_nodes:
+            actual_node = create_node(node, e)
+            org_outs = org_out_def.unflatten(e.outputs)
+            outputs, out_def = tree_flatten(
+                (org_outs, actual_node),
+                leaf_type=_leaf_type,
+                is_leaf=lambda x: isinstance(x, TensorNode),
+            )
+            e.outputs[:] = outputs
+            e.out_def = out_def
+            actual_out_nodes.append(actual_node)
+
+        module.argdef_outdef_map[tree_def] = out_def
+
+        return actual_out_nodes
+
+    def insert_function(self, func: Callable, *args, **kwargs):
+        assert isinstance(func, Callable)
+
+        inp_nodes, inp_def = tree_flatten(
+            (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
+        )
+
+        insert_idx = -1
+        for i in inp_nodes:
+            if isinstance(i, TensorNode) and i.expr in self._exprs:
+                insert_idx = max(insert_idx, self._exprs.index(i.expr))
+
+        fake_inp_val = list(
+            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
+            for i in inp_nodes
+        )
+
+        for v, n in zip(fake_inp_val, inp_nodes):
+            if isinstance(n, TensorNode):
+                NodeMixin.wrap_safe(v, n)
+
+        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
+
+        insert_point = self.insert_exprs_before()
+        if insert_idx != -1:
+            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
+
+        with insert_point:
+            rst = func(*fake_args, **fake_kwargs)
+
+        if rst is None:
+            return None
+
+        outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+        node_outputs = []
+        for out in outputs:
+            assert isinstance(out, RawTensor)
+            node_outputs.append(NodeMixin.get(out, None))
+
+        node_outputs = out_def.unflatten(node_outputs)
+        return node_outputs
+
+    def insert_exprs_after(self, expr: Optional[Expr] = None):
+        if expr is not None:
+            assert expr.top_graph == self, "Expr to insert after is not in graph."
+        return _InsertExprs(self, expr, after=True)
 
-        return out_nodes
+    def insert_exprs_before(self, expr: Optional[Expr] = None):
+        if expr is not None:
+            assert expr.top_graph == self, "Expr to insert before is not in graph."
+ return _InsertExprs(self, expr, after=False) def replace_node(self, repl_dict: Dict[Node, Node]): while repl_dict: @@ -246,7 +501,7 @@ class InternalGraph: i = 0 while i < len(self._exprs): expr = self._exprs[i] - if expr in dep_exprs: + if expr in dep_exprs or expr._disable_remove: i += 1 continue for n in expr.inputs: @@ -267,7 +522,7 @@ class InternalGraph: def __repr__(self): return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( ", ".join(str(i) for i in self._inputs), - "\n\t".join(str(i) for i in self._exprs), + "\n\t".join("{}".format(str(i)) for i in self._exprs), ", ".join(str(i) for i in self._outputs), ) @@ -293,7 +548,7 @@ def _wrapped_function(orig_func): if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) - meth_name = _get_meth_name(args[0], wrapped_fn) + meth_name = _get_meth_name(args[0], wrapped_fn) if args else None if meth_name: self = inputs[0] if meth_name == "__new__": @@ -316,10 +571,19 @@ def _wrapped_function(orig_func): call_node.add_inputs(inputs) call_node.arg_def = tree_def - outputs = orig_func(*args, **kwargs) + rst = orig_func(*args, **kwargs) + if meth_name == "__setitem__": + rst = self + if rst is not None: + outputs, out_def = tree_flatten( + rst, leaf_type=_leaf_type, is_leaf=_is_leaf + ) + call_node.out_def = out_def + else: + outputs = None call_node.add_outputs(outputs) set_module_tracing() - return outputs + return rst return orig_func(*args, **kwargs) return wrapped_fn @@ -349,6 +613,7 @@ class TracedModuleBuilder(NodeMixin): super(TracedModuleBuilder, self).__init__() self._mod = mod self._body = None + self._is_top = is_top_module self._is_builtin = module_tracer.is_builtin(mod) self._argdef_graph_map = {} self._argdef_outdef_map = {} @@ -362,7 +627,7 @@ class TracedModuleBuilder(NodeMixin): return self._mod else: traced_module = TracedModule( - self._argdef_graph_map, self._argdef_outdef_map + self._is_top, self._argdef_graph_map, self._argdef_outdef_map ) for _, g in self._argdef_graph_map.items(): g.compile() @@ -408,8 +673,8 @@ class TracedModuleBuilder(NodeMixin): self._body = None else: self_node = None - if self._body: - self_node = self._body.inputs[0] + if tree_def in self._argdef_graph_map: + self_node = self._argdef_graph_map[tree_def].inputs[0] self._body = InternalGraph() active_module_tracer().push_scope(self._body) # rebind self to new input node @@ -446,7 +711,7 @@ class TracedModuleBuilder(NodeMixin): outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) ): active_module_tracer().current_scope().add_output(NodeMixin.get(i)) - + NodeMixin.get(self, None).actual_mnode.append(orig_self) NodeMixin.wrap_safe(self, orig_self) for arg, node in zip(inputs[1:], origin_inp_node): if node: @@ -454,6 +719,7 @@ class TracedModuleBuilder(NodeMixin): active_module_tracer().pop_scope() # rebind output to outer graph + callnode.out_def = out_def callnode.add_outputs(outputs) self._argdef_graph_map[callnode.arg_def] = self._body self._argdef_outdef_map[callnode.arg_def] = out_def @@ -512,31 +778,44 @@ class _expr_iter: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr if expr.graph is not None: - yield from expr.graph.exprs + yield from expr.graph.expr_filter else: yield expr -class ExprFilter: +class _node_iter: + def __init__(self, graph: InternalGraph) -> None: + nodes = [] + node_ids = set() + for expr in graph.expr_filter: + for n in expr.inputs + expr.outputs: + if n._id in node_ids: + continue + nodes.append(n) + 
node_ids.add(n._id) + self.nodes = list(sorted(nodes, key=lambda x: x._id)) + + def __iter__(self): + for node in self.nodes: + yield node + + +class BaseFilter: def __init__(self, expr_iter: Iterable): self._iter = expr_iter def __iter__(self): return iter(self._iter) - def call_function(self, func): - return ExprFilterCallFunction(self, func) - - def call_method(self, method): - return ExprFilterCallMethod(self, method) - def as_list(self): return list(self) def as_dict(self): - raise NotImplementedError("need key") + return collections.OrderedDict((i._id, i) for i in self) def as_unique(self): + rst = self.as_list() + assert len(rst) == 1, "{} elements found".format(len(rst)) (expr,) = self return expr @@ -544,17 +823,65 @@ class ExprFilter: return sum(1 for _ in self) +class ExprFilter(BaseFilter): + def call_function(self, func): + return ExprFilterCallFunction(self, func) + + def call_method(self, method): + return ExprFilterCallMethod(self, method) + + def expr_id(self, expr_id: List[int]): + return ExprFilterExprId(self, expr_id) + + +class NodeFilter(BaseFilter): + def type(self, owner_type, node_type): + return NodeFilterType(self, owner_type, node_type) + + def node_id(self, node_id: List[int]): + return NodeFilterNodeId(self, node_id) + + +class NodeFilterType(NodeFilter): + def __init__(self, expr_iter, owner_type, node_type): + super().__init__(expr_iter) + self.owner_type = owner_type + self.node_type = node_type + + def __iter__(self): + for node in self._iter: + if not isinstance(node, self.node_type): + continue + if not hasattr(node, "owner"): + continue + if isinstance(node.owner, self.owner_type): + yield node + + +class NodeFilterNodeId(NodeFilter): + def __init__(self, expr_iter, node_id: List[int]): + super().__init__(expr_iter) + if not isinstance(node_id, Sequence): + node_id = [node_id] + self.node_id = node_id + + def __iter__(self): + for node in self._iter: + if node._id in self.node_id: + yield node + + class ExprFilterCallFunction(ExprFilter): def __init__(self, expr_iter, func: Callable = None): super().__init__(expr_iter) self.func = func def __iter__(self): - for i in self._iter: - if not isinstance(i, CallFunction): + for expr in self._iter: + if not isinstance(expr, CallFunction): continue - if self.func is None or i.func == self.func: - yield i + if self.func is None or expr.func == self.func: + yield expr class ExprFilterCallMethod(ExprFilter): @@ -563,11 +890,24 @@ class ExprFilterCallMethod(ExprFilter): self.method = method def __iter__(self): - for i in self._iter: - if not isinstance(i, CallMethod): + for expr in self._iter: + if not isinstance(expr, CallMethod): continue - if self.method is None or i.method == self.method: - yield i + if self.method is None or expr.method == self.method: + yield expr + + +class ExprFilterExprId(ExprFilter): + def __init__(self, expr_iter, expr_id: List[int]): + super().__init__(expr_iter) + if not isinstance(expr_id, Sequence): + expr_id = [expr_id] + self.expr_id = expr_id + + def __iter__(self): + for expr in self._iter: + if expr._id in self.expr_id: + yield expr class TracedModule(Module): @@ -579,10 +919,11 @@ class TracedModule(Module): argdef_graph_map = None argdef_outdef_map = None - def __init__(self, argdef_graph_map, argdef_outdef_map): + def __init__(self, is_top, argdef_graph_map, argdef_outdef_map): super(TracedModule, self).__init__() self.argdef_graph_map = argdef_graph_map self.argdef_outdef_map = argdef_outdef_map + self._is_top = is_top def forward(self, *args, **kwargs): inputs, treedef = 
tree_flatten(
@@ -598,29 +939,58 @@
         return outputs
 
     @property
-    def graph(self):
-        self._update_modulenode_ref()
+    def graph(self) -> InternalGraph:
+        if self._is_top:
+            self._update_ref()
         assert len(self.argdef_graph_map) == 1
         return list(self.argdef_graph_map.values())[0]
 
-    def _update_modulenode_ref(self):
-        for _, graph in self.argdef_graph_map.items():
+    def _update_ref(self, actual_node_map: Optional[Dict] = None):
+        for inp_def, graph in self.argdef_graph_map.items():
+            for n in graph._inputs + graph.outputs:
+                n._top_graph = weakref.ref(graph)
             graph._inputs[0]._owner = weakref.ref(self)
+            graph._inputs[0].actual_mnode = []
+            if actual_node_map is not None and inp_def in actual_node_map.keys():
+                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
             node2obj = {}
+            next_actual_node_map = collections.defaultdict(
+                lambda: collections.defaultdict(list)
+            )
             node2obj[graph._inputs[0]] = self
             for expr in graph._exprs:
+                for n in expr.inputs + expr.outputs:
+                    n._top_graph = weakref.ref(graph)
+                expr._top_graph = weakref.ref(graph)
                 if isinstance(expr, GetAttr) and isinstance(
                     expr.outputs[0], ModuleNode
                 ):
                     obj = getattr(node2obj[expr.inputs[0]], expr.name)
                     expr.outputs[0]._owner = weakref.ref(obj)
                     node2obj[expr.outputs[0]] = obj
-                    if isinstance(obj, TracedModule):
-                        obj._update_modulenode_ref()
+                if isinstance(expr, Constant) and isinstance(
+                    expr.outputs[0], ModuleNode
+                ):
+                    obj = expr.value
+                    expr.outputs[0]._owner = weakref.ref(obj)
+                    node2obj[expr.outputs[0]] = obj
+                if (
+                    isinstance(expr, CallMethod)
+                    and expr.method == "__call__"
+                    and isinstance(expr.inputs[0], ModuleNode)
+                ):
+                    obj = node2obj[expr.inputs[0]]
+                    if expr.arg_def is not None:
+                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
 
-    @property
-    def exprs(self):
-        return self.graph.exprs
+            for obj in node2obj.values():
+                if obj is self:
+                    continue
+                mnode_map = None
+                if obj in next_actual_node_map.keys():
+                    mnode_map = next_actual_node_map[obj]
+                if isinstance(obj, TracedModule):
+                    obj._update_ref(mnode_map)
 
     def flatten(self):
         """
@@ -644,13 +1014,21 @@
                 node2obj[graph._inputs[0]] = module
                 if call:
                     node2obj[call.inputs[0]] = module
+                repl_dict = dict(zip(graph._inputs, call.inputs))
+                for ind, out in enumerate(graph.outputs):
+                    if isinstance(out.expr, Input):
+                        assert out in repl_dict
+                        call_out = call.outputs[ind]
+                        for expr in call.outputs[ind].users:
+                            for index, inp in enumerate(expr.inputs):
+                                if inp is call_out:
+                                    expr.inputs[index] = repl_dict[out]
+
+                        continue
+                    repl_dict[out] = call.outputs[ind]
+
+                graph._replace_inputs_outputs(repl_dict)
                 for expr in graph._exprs:
-                    # replace inputs for submodule's exprx
-                    if call:
-                        repl_dict = dict(
-                            zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
-                        )
-                        graph._replace_inputs_outputs(repl_dict)
 
                     if isinstance(expr, GetAttr):
                         # replace GetAttr with Constant
@@ -715,6 +1093,21 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
     module_tracer.register_as_builtin(mod_cls)
 
 
+def wrap(func: Callable):
+    assert callable(func)
+    if hasattr(func, "__code__"):
+        assert not isinstance(func, str)
+        fn_name = func.__code__.co_name
+    currentframe = inspect.currentframe()
+    assert currentframe is not None
+    f = currentframe.f_back
+    assert f is not None
+    if f.f_code.co_name != "<module>":
+        raise NotImplementedError("wrap must be called at the top level of a module")
+    Patcher._builtin_functions.append((f.f_globals, fn_name))
+    return func
+
+
 def _register_all_builtin_module():
     for sub_mod in [M, M.qat, M.quantized]:
@@ -749,6 +1142,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
     inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
     for _, i in enumerate(inputs):
+        assert isinstance(i, Tensor), "only Tensor inputs are supported"
         if isinstance(i, RawTensor):
             NodeMixin.wrap_safe(
                 i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py
index 692fbb0b..5ccdcfa9 100644
--- a/imperative/python/test/unit/traced_module/test_modification.py
+++ b/imperative/python/test/unit/traced_module/test_modification.py
@@ -57,16 +57,16 @@ def _init_module():
 def test_search():
     traced_module, *_ = _init_block()
     graph = traced_module.graph
-    relu_expr = graph.get_call_function(F.relu).as_unique()
+    relu_expr = graph.get_function_by_type(F.relu).as_unique()
     assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
 
 
 def test_insert():
     traced_module, x, expect = _init_block()
     graph = traced_module.graph
-    relu_node = graph.get_call_function(F.relu).as_unique().outputs
-    neg_node = graph.insert_call_function(F.neg, relu_node)
-    graph.replace_node({relu_node[0]: neg_node[0]})
+    relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
+    neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
+    graph.replace_node({relu_node[0]: neg_node})
     graph.compile()
     np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
 
 
@@ -74,7 +74,7 @@
 def test_delete():
     traced_module, x, expect = _init_block()
     graph = traced_module.graph
-    relu_expr = graph.get_call_function(F.relu).as_unique()
+    relu_expr = graph.get_function_by_type(F.relu).as_unique()
     node = relu_expr.outputs
     repl_node = relu_expr.inputs
     graph.replace_node({node[0]: repl_node[0]})
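
End-to-end usage of the new modification API, mirroring `test_insert` above. This is a sketch only: `Block` is an illustrative stand-in for the tests' `_init_block` fixture, which this patch does not show.

    import numpy as np
    import megengine.functional as F
    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class Block(M.Module):
        def forward(self, x):
            return F.relu(x)

    x = Tensor(np.random.random((2, 3)).astype("float32"))
    traced = trace_module(Block(), x)
    graph = traced.graph

    # Locate the relu expr, trace a neg over its output via insert_function,
    # rewire users of the old output to the new node, then regenerate forward.
    relu_out = graph.get_function_by_type(F.relu).as_unique().outputs
    neg_out = graph.insert_function(lambda v: F.neg(v), *relu_out)
    graph.replace_node({relu_out[0]: neg_out})
    graph.compile()

    np.testing.assert_allclose(F.neg(F.relu(x)), traced(x), atol=1e-6)
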