Browse Source

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: 09691ebd33
release-1.6
Megvii Engine Team 3 years ago
parent
commit
c7e730bc12
6 changed files with 612 additions and 124 deletions
  1. +2
    -0
      imperative/python/megengine/experimental/traced_module/__init__.py
  2. +68
    -13
      imperative/python/megengine/experimental/traced_module/expr.py
  3. +22
    -8
      imperative/python/megengine/experimental/traced_module/node.py
  4. +23
    -0
      imperative/python/megengine/experimental/traced_module/pytree.py
  5. +492
    -98
      imperative/python/megengine/experimental/traced_module/traced_module.py
  6. +5
    -5
      imperative/python/test/unit/traced_module/test_modification.py

+ 2
- 0
imperative/python/megengine/experimental/traced_module/__init__.py View File

@@ -13,6 +13,8 @@ from .traced_module import (
cpp_apply_module_trace, cpp_apply_module_trace,
register_as_builtin, register_as_builtin,
trace_module, trace_module,
wrap,
wrap_tensors,
) )


_register_all_builtin_module() _register_all_builtin_module()


+ 68
- 13
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -11,7 +11,7 @@ import builtins
import collections import collections
import copy import copy
import inspect import inspect
from typing import Callable, List
from typing import Callable, Dict, List


from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -29,10 +29,24 @@ class Expr:
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
""" """


__total_id = 0
inputs = None # type: List[Node] inputs = None # type: List[Node]
outputs = None # type: List[Node] outputs = None # type: List[Node]
const_val = None # type: List[Any] const_val = None # type: List[Any]
arg_def = None # type: TreeDef arg_def = None # type: TreeDef
out_def = None # type: TreeDef
_top_graph = None # type: weakref.ReferenceType

def __init__(self) -> None:
    """Assign a unique id to the expression and clear its remove guard.

    The id is taken from the class-level ``__total_id`` counter, which is
    advanced by one for every constructed expression.
    """
    # A freshly created expression is not protected from removal.
    self._disable_remove = False
    self._id = Expr.__total_id
    Expr.__total_id += 1

def enable_remove(self):
    """Clear the ``_disable_remove`` flag of this expression."""
    self._disable_remove = False

def disable_remove(self):
    """Set the ``_disable_remove`` flag of this expression."""
    self._disable_remove = True


def add_inputs(self, vals): def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence): if not isinstance(vals, collections.abc.Sequence):
@@ -70,6 +84,22 @@ class Expr:
else: else:
return inputs, {} return inputs, {}


def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
repl_node.users.append(self)
node.users.pop(self)

def replace_inputs(self, repl_dict: Dict[Node, Node]):
    """Replace input nodes of this expression according to ``repl_dict``.

    Delegates to ``_replace_nodes`` with ``self.inputs`` as the list to
    patch.
    """
    self._replace_nodes(repl_dict, self.inputs)

def replace_outputs(self, repl_dict: Dict[Node, Node]):
    """Replace output nodes of this expression according to ``repl_dict``.

    Delegates to ``_replace_nodes`` with ``self.outputs`` as the list to
    patch.
    """
    self._replace_nodes(repl_dict, self.outputs)

@property @property
def kwargs(self): def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs) _, kwargs = self.unflatten_args(self.inputs)
@@ -80,12 +110,19 @@ class Expr:
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args


@property
def top_graph(self):
    """Return the object referenced by ``_top_graph``, or ``None`` if unset.

    ``_top_graph`` holds a callable reference (annotated as a weak
    reference on the class); calling it yields the referenced graph.
    """
    graph_ref = self._top_graph
    return graph_ref() if graph_ref else None



# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
name = None name = None


def __init__(self, name=None, type=None): def __init__(self, name=None, type=None):
super().__init__()
self.inputs = [] self.inputs = []
node_cls = type if type else Node node_cls = type if type else Node
self.outputs = [ self.outputs = [
@@ -100,7 +137,7 @@ class Input(Expr):
return expr.outputs[0] return expr.outputs[0]


def __repr__(self): def __repr__(self):
return "{} = Input({})".format(self.outputs[0], self.name)
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)




# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
@@ -108,6 +145,7 @@ class GetAttr(Expr):
name = None name = None


def __init__(self, module, name, type=None): def __init__(self, module, name, type=None):
super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
self.inputs = [ self.inputs = [
module, module,
@@ -130,14 +168,15 @@ class GetAttr(Expr):
return (getattr(inputs[0], self.name),) return (getattr(inputs[0], self.name),)


def __repr__(self): def __repr__(self):
return '{} = GetAttr({}, "{}")'.format(
self.outputs[0], self.inputs[0], self.name
return '%{}: {} = GetAttr({}, "{}")'.format(
self._id, self.outputs[0], self.inputs[0], self.name
) )




# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr): class CallMethod(Expr):
def __init__(self, node, method="__call__"): def __init__(self, node, method="__call__"):
super().__init__()
if isinstance(node, type): if isinstance(node, type):
assert issubclass(node, Tensor) assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor cls = Parameter if issubclass(node, Parameter) else Tensor
@@ -178,6 +217,8 @@ class CallMethod(Expr):
if inspect.ismethod(meth): if inspect.ismethod(meth):
args = args[1:] args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs) outputs = getattr(obj, self.method)(*args, **kwargs)
if self.method == "__setitem__":
outputs = obj
if outputs is None: if outputs is None:
return outputs return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
@@ -186,8 +227,12 @@ class CallMethod(Expr):
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:]) args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.args[0], self.args[0],
self.method, self.method,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
@@ -199,6 +244,7 @@ class Apply(Expr):
opdef = None opdef = None


def __init__(self, opdef): def __init__(self, opdef):
super().__init__()
assert isinstance(opdef, OpDef) assert isinstance(opdef, OpDef)
self.opdef = opdef self.opdef = opdef
self.inputs = [] self.inputs = []
@@ -213,7 +259,8 @@ class Apply(Expr):
return apply(self.opdef, *inputs) return apply(self.opdef, *inputs)


def __repr__(self): def __repr__(self):
return "{} = {}({})".format(
return "%{}: {} = {}({})".format(
self._id,
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.opdef, self.opdef,
", ".join(str(i) for i in self.inputs), ", ".join(str(i) for i in self.inputs),
@@ -241,6 +288,7 @@ class Apply(Expr):


class CallFunction(Expr): class CallFunction(Expr):
def __init__(self, func): def __init__(self, func):
super().__init__()
assert isinstance(func, Callable) assert isinstance(func, Callable)
self.func = func self.func = func
self.const_val = [] self.const_val = []
@@ -255,16 +303,20 @@ class CallFunction(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs) args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs) outputs = self.func(*args, **kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs return outputs


def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args) args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__, self.func.__module__ + "." + self.func.__name__,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
@@ -277,6 +329,7 @@ class Constant(Expr):
_constant_cache = {} _constant_cache = {}


def __init__(self, c): def __init__(self, c):
super().__init__()
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) assert module_tracer.is_builtin(c)
@@ -299,7 +352,9 @@ class Constant(Expr):
return (self.value,) return (self.value,)


def __repr__(self): def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], type(self.value))
return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
)


def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()


+ 22
- 8
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -30,6 +30,7 @@ class Node:
__total_id = 0 __total_id = 0
_id = None _id = None
_name = None _name = None
_top_graph = None # type: weakref.ReferenceType


def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
self.expr = expr self.expr = expr
@@ -48,6 +49,12 @@ class Node:
else: else:
return "%{}".format(self._name) return "%{}".format(self._name)


@property
def top_graph(self):
    """Return the object referenced by ``_top_graph``, or ``None`` if unset.

    ``_top_graph`` holds a callable reference (annotated as a weak
    reference on the class); calling it yields the referenced graph.
    """
    graph_ref = self._top_graph
    return graph_ref() if graph_ref else None



class ModuleNode(Node): class ModuleNode(Node):
""" """
@@ -64,21 +71,28 @@ class ModuleNode(Node):


def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name) super().__init__(expr, name)
self.actual_mnode = []


def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}({})".format(self._id, self.module_type.__name__)
return "%{}_({})".format(self._id, self.module_type.__name__)
else: else:
return "%{}({})".format(self._name, self.module_type.__name__)
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)


def __getstate__(self): def __getstate__(self):
d = self.__dict__
d.pop("_owner", None)
return d
return {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"_name": self._name,
"module_type": self.module_type,
}


@property @property
def owner(self): def owner(self):
return self._owner()
if self._owner:
return self._owner()
return None




class TensorNode(Node): class TensorNode(Node):
@@ -91,9 +105,9 @@ class TensorNode(Node):


def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}(Tensor)".format(self._id)
return "%{}_(Tensor)".format(self._id)
else: else:
return "%{}(Tensor)".format(self._name)
return "%{}_{}(Tensor)".format(self._id, self._name)




class NodeMixin(abc.ABC): class NodeMixin(abc.ABC):


+ 23
- 0
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import collections import collections
from collections import OrderedDict
from typing import Callable, NamedTuple from typing import Callable, NamedTuple


import numpy as np import numpy as np
@@ -34,10 +35,26 @@ def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps)) return dict(zip(aux_data, inps))




def _ordereddict_flatten(inp):
aux_data = []
results = []
for key, value in inp.items():
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)


def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten) register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type( register_supported_type(
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
)
register_supported_type(
slice, slice,
lambda x: ([x.start, x.stop, x.step], None), lambda x: ([x.start, x.stop, x.step], None),
lambda x, aux_data: slice(x[0], x[1], x[2]), lambda x, aux_data: slice(x[0], x[1], x[2]),
@@ -99,6 +116,12 @@ class TreeDef:
) )
) )


def __lt__(self, other):
return self.__hash__() < other.__hash__()

def __gt__(self, other):
return self.__hash__() > other.__hash__()

def __eq__(self, other): def __eq__(self, other):
return ( return (
self.type == other.type self.type == other.type


+ 492
- 98
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,12 +9,10 @@
import collections import collections
import copy import copy
import functools import functools
import inspect
import weakref import weakref
from inspect import getmembers, isclass, ismethod from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type

import numpy as np
from numpy.lib.arraysetops import isin
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union


from ... import functional as F from ... import functional as F
from ... import get_logger from ... import get_logger
@@ -43,9 +41,9 @@ logger = get_logger(__name__)




def _leaf_type(node): def _leaf_type(node):
if isinstance(node, RawTensor):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode) return (Tensor, TensorNode)
elif isinstance(node, (NodeMixin, Module)):
elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin) return (Module, ModuleNode, NodeMixin)
else: else:
return type(node) return type(node)
@@ -64,6 +62,50 @@ def _is_const_leaf(node):
return True return True




def wrap_tensors(tensors: Tensor, nodes: TensorNode):
    """Return a deep copy of ``tensors`` whose leaves are bound to ``nodes``.

    Both arguments are flattened with ``tree_flatten``; the flattened leaves
    are paired positionally, and each ``Tensor`` leaf paired with a
    ``TensorNode`` is bound to it through ``NodeMixin.wrap_safe``. The copy
    is then rebuilt in the structure of ``tensors``. The objects passed in
    as ``tensors`` are not wrapped themselves, only their copies.
    """
    flat_tensors, tensors_def = tree_flatten(
        copy.deepcopy(tensors), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
    )
    flat_nodes, _ = tree_flatten(
        nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
    )
    for tensor, node in zip(flat_tensors, flat_nodes):
        if isinstance(node, TensorNode) and isinstance(tensor, Tensor):
            NodeMixin.wrap_safe(tensor, node)
    return tensors_def.unflatten(flat_tensors)


class _InsertExprs:
    """Context manager that records new expressions and splices them into a graph.

    On entry a module tracer is activated with a scratch ``InternalGraph``
    pushed as its scope. On exit the tracer is deactivated and every
    expression collected in the scratch graph is inserted into
    ``graph._exprs``:

    * relative to ``expr`` (directly after it when ``after`` is true,
      directly before it otherwise) if ``expr`` is given;
    * at the end of the graph when ``expr`` is ``None`` and ``after`` is true;
    * at the start of the graph when ``expr`` is ``None`` and ``after`` is
      false.
    """

    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
        self.graph = graph
        self.global_scope = InternalGraph()
        self.expr = expr
        self.after = after

    def __enter__(self):
        # Remember the previous symbolic-shape setting so __exit__ restores it.
        self.use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        assert active_module_tracer() is None
        set_active_module_tracer(module_tracer(_wrapped_function))
        active_module_tracer().patcher.__enter__()
        active_module_tracer().push_scope(self.global_scope)

    def __exit__(self, ty, va, tr):
        set_symbolic_shape(self.use_sym_shape)
        unset_module_tracing()
        active_module_tracer().patcher.__exit__(ty, va, tr)
        set_active_module_tracer(None)

        target = self.graph._exprs
        if self.expr is not None:
            start = target.index(self.expr) + (1 if self.after else 0)
        elif self.after:
            start = len(target)
        else:
            start = 0
        # Insert the recorded expressions consecutively, preserving order.
        for offset, recorded in enumerate(self.global_scope._exprs):
            target.insert(start + offset, recorded)


class InternalGraph: class InternalGraph:
""" """
``InternalGraph`` is a graph consisting of ``Node`` and ``Expr``; it is used to represent the execution procedure of a Module's forward method. ``InternalGraph`` is a graph consisting of ``Node`` and ``Expr``; it is used to represent the execution procedure of a Module's forward method.
@@ -95,14 +137,28 @@ class InternalGraph:
return self._outputs return self._outputs


@property @property
def exprs(self):
def expr_filter(self):
return ExprFilter(_expr_iter(self)) return ExprFilter(_expr_iter(self))


def get_call_function(self, func: Callable = None):
return self.exprs.call_function(func)
@property
def node_filter(self):
    """Return a ``NodeFilter`` built from ``_node_iter`` over this graph."""
    nodes = _node_iter(self)
    return NodeFilter(nodes)

def get_function_by_type(self, func: Callable = None):
    """Filter this graph's expressions via ``expr_filter.call_function(func)``.

    Args:
        func: forwarded unchanged to the filter; defaults to ``None``.
    """
    exprs = self.expr_filter
    return exprs.call_function(func)

def get_method_by_type(self, method: str = None):
    """Filter this graph's expressions via ``expr_filter.call_method(method)``.

    Args:
        method: forwarded unchanged to the filter; defaults to ``None``.
    """
    exprs = self.expr_filter
    return exprs.call_method(method)


def get_call_method(self, method: str = None):
return self.exprs.call_method(method)
def get_expr_by_id(self, expr_id: List[int] = None):
    """Filter this graph's expressions via ``expr_filter.expr_id(expr_id)``.

    Args:
        expr_id: forwarded unchanged to the filter; defaults to ``None``.
    """
    exprs = self.expr_filter
    return exprs.expr_id(expr_id)

def get_module_by_type(self, module_cls: Module):
    """Filter this graph's nodes down to ``ModuleNode`` s of a module class.

    Args:
        module_cls: a subclass of ``Module``; passed to
            ``node_filter.type`` together with ``ModuleNode``.

    Raises:
        AssertionError: if ``module_cls`` is not a subclass of ``Module``.
    """
    assert issubclass(module_cls, Module)
    nodes = self.node_filter
    return nodes.type(module_cls, ModuleNode)

def get_node_by_id(self, node_id: List[int] = None):
    """Filter this graph's nodes via ``node_filter.node_id(node_id)``.

    Args:
        node_id: forwarded unchanged to the filter; defaults to ``None``.
    """
    nodes = self.node_filter
    return nodes.node_id(node_id)


def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)
@@ -124,7 +180,6 @@ class InternalGraph:
for idx, o in enumerate(self._outputs): for idx, o in enumerate(self._outputs):
if o in repl_dict: if o in repl_dict:
self._outputs[idx] = repl_dict[o] self._outputs[idx] = repl_dict[o]
self._outputs[idx].expr = node.expr


for expr in self._exprs: for expr in self._exprs:


@@ -135,83 +190,283 @@ class InternalGraph:
for idx, o in enumerate(expr.outputs): for idx, o in enumerate(expr.outputs):
if o in repl_dict: if o in repl_dict:
expr.outputs[idx] = repl_dict[o] expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr


def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
nodes = (nodes,) nodes = (nodes,)
ret = list() ret = list()
queue = list(nodes) queue = list(nodes)
visited_queue = list()
while queue: while queue:
node = queue.pop() node = queue.pop()
visited_queue.append(node)

expr = node.expr expr = node.expr

if expr not in ret: if expr not in ret:
ret.append(expr) ret.append(expr)


for i in expr.inputs: for i in expr.inputs:
if i not in queue:
if i not in queue and i not in visited_queue:
queue.append(i) queue.append(i)
return ret return ret


def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
if not isinstance(nodes, Sequence):
nodes = [nodes]
assert isinstance(func, Callable)
for i in nodes:
assert isinstance(
i, TensorNode
), "CallFunction only accept TensorNode as inputs"
def reset_inputs(self, *args, **kwargs):
    """Replace the inputs of this (top-level) graph with new tensor inputs.

    ``args``/``kwargs`` are example call arguments. They are flattened
    together with the owning module; every flattened value after the module
    must be a ``RawTensor`` and gets a fresh ``Input`` node carrying its
    shape and dtype. Every ``__call__`` ``CallMethod`` found among the users
    of the module's actual module nodes is rewritten to take fresh input
    nodes and the new argument definition, and the module's
    ``argdef_graph_map`` / ``argdef_outdef_map`` entries are re-keyed to
    that definition.

    Returns:
        The list of new formal tensor input nodes (module node excluded).

    Raises:
        AssertionError: if the owning module is not the top-level module or
            a flattened argument is not a ``RawTensor``.
    """
    formal_mnode = self.inputs[0]
    actual_mnodes = formal_mnode.actual_mnode

    # (call expression, actual module node) pairs for every module call.
    call_nodes = []
    for n in actual_mnodes:
        for c_expr in n.users:
            if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                call_nodes.append((c_expr, n))

    module = formal_mnode.owner
    assert module._is_top, "reset_inputs only support the top-level graph"

    inputs, tree_def = tree_flatten(
        ((module, *args), kwargs),
        leaf_type=_leaf_type,
        is_const_leaf=_is_const_leaf,
    )

    def create_node(val: Tensor):
        node = Input(type=TensorNode).outputs[0]
        node.shape = val.shape
        node.dtype = val.dtype
        return node

    formal_node_inputs = [
        formal_mnode,
    ]

    org_argdef = list(module.argdef_graph_map.keys())[0]
    if call_nodes:
        org_argdef = call_nodes[0][0].arg_def

    for v in inputs[1:]:
        assert isinstance(v, RawTensor)
        formal_node_inputs.append(create_node(v))

    actual_nodes = []
    for e, n in call_nodes:
        e.arg_def = tree_def
        actual_node_inputs = [
            n,
        ]
        for v in inputs[1:]:
            actual_node_inputs.append(create_node(v))

        for org_n in e.inputs:
            # ``users`` is a list, so ``users.pop(e)`` raised TypeError
            # (``pop`` expects an index). Unlink by value instead. The module
            # node ``n`` stays an input of ``e``, so its link is kept.
            if org_n is not n and e in org_n.users:
                org_n.users.remove(e)

        # NOTE(review): the freshly created input nodes are not given ``e``
        # as a user here - confirm whether callers rely on that.
        e.inputs[:] = actual_node_inputs
        e.const_val = []
        actual_nodes.append(actual_node_inputs[1:])

    self._inputs[:] = formal_node_inputs
    module.argdef_graph_map[tree_def] = module.argdef_graph_map.pop(org_argdef)
    module.argdef_outdef_map[tree_def] = module.argdef_outdef_map.pop(org_argdef)

    return formal_node_inputs[1:]

def add_input_node(self, shape, dtype="float32"):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode

moudle = forma_mnode.owner
assert moudle._is_top, "add_input_node only support the top-level graph"

call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append(c_expr)

def create_node(is_input: bool = True):
if is_input:
node = Input(type=TensorNode).outputs[0]
else:
node = TensorNode(expr=None)
node.shape = shape
node.dtype = dtype
return node


fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
fake_out_val = func(*fake_inp_val)
org_argdef = list(moudle.argdef_graph_map.keys())[0]


def create_node(val: Tensor):
if call_nodes:
org_argdef = call_nodes[0].arg_def

args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(True)
inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
self._inputs[:] = inputs[:]

actual_inp_nodes = []
for e in call_nodes:
args, kwargs = e.unflatten_args(e.inputs)
args = args + (create_node(False),)
inputs, tree_def = tree_flatten(
(args, kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
e.inputs[:] = inputs[:]
e.arg_def = tree_def
actual_inp_nodes.append(args[-1])

moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)

# return formal_inp_node, actual_inp_nodes
return formal_inp_node

def reset_outputs(self, outputs):
outputs, out_def = tree_flatten(
outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
)
forma_mnode = self.inputs[0]

moudle = forma_mnode.owner
assert moudle._is_top, "reset_outputs only support the top-level graph"

actual_mnodes = forma_mnode.actual_mnode
call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append((c_expr))

def create_node(val: TensorNode, expr: Expr):
node = TensorNode(expr) node = TensorNode(expr)
node.shape = val.shape node.shape = val.shape
node.dtype = val.dtype node.dtype = val.dtype
return node return node


out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
tree_def = list(moudle.argdef_graph_map.keys())[0]
if call_nodes:
tree_def = call_nodes[0].arg_def


return out_nodes
actual_nodes = []
for e in call_nodes:
actual_node_outputs = []
for v in outputs:
actual_node_outputs.append(create_node(v, e))
e.outputs[:] = actual_node_outputs
e.out_def = out_def
actual_nodes.append(actual_node_outputs)


def insert_call_method(self, target, method, args):
if not isinstance(args, Sequence):
args = [args]
assert isinstance(target, (TensorNode, ModuleNode))
assert isinstance(method, str)
for i in args:
assert isinstance(i, TensorNode)
self._outputs[:] = outputs
moudle.argdef_outdef_map[tree_def] = out_def


expr = CallMethod(method)
expr.inputs = [target, *args]
return actual_nodes


if isinstance(target, TensorNode):
fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
def add_output_node(self, node: TensorNode):
forma_mnode = self.inputs[0]


def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node
moudle = forma_mnode.owner
assert moudle._is_top, "add_output_node only support the top-level graph"


out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
else:
raise NotImplementedError()
actual_mnodes = forma_mnode.actual_mnode
call_nodes = []

for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append((c_expr))

def create_node(val: TensorNode, expr: Expr):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node

tree_def = list(moudle.argdef_graph_map.keys())[0]
if call_nodes:
tree_def = call_nodes[0].arg_def

org_out_def = moudle.argdef_outdef_map[tree_def]
org_outs = org_out_def.unflatten(self._outputs)
outputs, out_def = tree_flatten(
(org_outs, node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
)
self._outputs[:] = outputs

actual_out_nodes = []
for e in call_nodes:
actual_node = create_node(node, e)
org_outs = org_out_def.unflatten(e.outputs)
outputs, out_def = tree_flatten(
(org_outs, actual_node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
)
e.outputs[:] = outputs
e.out_def = out_def
actual_out_nodes.append(actual_node)

moudle.argdef_outdef_map[tree_def] = out_def

return actual_out_nodes

def insert_function(self, func: Callable, *args, **kwargs):
    """Record a call of ``func`` and insert its expressions into this graph.

    ``args``/``kwargs`` may contain ``TensorNode`` s. Each of them is
    replaced by a zero tensor with the node's shape and dtype, bound to the
    node, and ``func`` is called on these placeholders inside an insertion
    context. The recorded expressions are placed after the last expression
    of this graph that produces one of the given tensor nodes, or at the
    start of the graph when none of them is produced here.

    Returns:
        ``None`` if ``func`` returns ``None``; otherwise the nodes bound to
        the returned tensors, arranged in the structure of ``func``'s
        result.

    Raises:
        AssertionError: if ``func`` is not callable or a flattened result
            is not a ``RawTensor``.
    """
    assert isinstance(func, Callable)

    inp_nodes, inp_def = tree_flatten(
        (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
    )

    # Position of the latest expression that produces one of the inputs.
    producer_positions = [
        self._exprs.index(n.expr)
        for n in inp_nodes
        if isinstance(n, TensorNode) and n.expr in self._exprs
    ]
    insert_idx = max(producer_positions, default=-1)

    fake_inp_val = [
        F.zeros(shape=n.shape, dtype=n.dtype) if isinstance(n, TensorNode) else n
        for n in inp_nodes
    ]

    for fake, n in zip(fake_inp_val, inp_nodes):
        if isinstance(n, TensorNode):
            NodeMixin.wrap_safe(fake, n)

    fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)

    insert_point = self.insert_exprs_before()
    if insert_idx != -1:
        insert_point = self.insert_exprs_after(self._exprs[insert_idx])

    with insert_point:
        rst = func(*fake_args, **fake_kwargs)

    if rst is None:
        return None

    flat_outs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
    node_outputs = []
    for out in flat_outs:
        assert isinstance(out, RawTensor)
        node_outputs.append(NodeMixin.get(out, None))

    return out_def.unflatten(node_outputs)

def insert_exprs_after(self, expr: Optional[Expr] = None):
    """Return an insertion context that places new expressions after ``expr``.

    Args:
        expr: anchor expression; when given, its ``top_graph`` must be this
            graph. ``None`` is passed through to ``_InsertExprs`` unchanged.

    Raises:
        AssertionError: if ``expr`` is given but does not belong to this
            graph.
    """
    assert (
        expr is None or expr.top_graph == self
    ), "Expr to insert after is not in graph."
    return _InsertExprs(self, expr, after=True)


return out_nodes
def insert_exprs_before(self, expr: Optional[Expr] = None):
    """Return an insertion context that places new expressions before ``expr``.

    Args:
        expr: anchor expression; when given, its ``top_graph`` must be this
            graph. ``None`` is passed through to ``_InsertExprs`` unchanged.

    Raises:
        AssertionError: if ``expr`` is given but does not belong to this
            graph.
    """
    assert (
        expr is None or expr.top_graph == self
    ), "Expr to insert before is not in graph."
    return _InsertExprs(self, expr, after=False)


def replace_node(self, repl_dict: Dict[Node, Node]): def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict: while repl_dict:
@@ -246,7 +501,7 @@ class InternalGraph:
i = 0 i = 0
while i < len(self._exprs): while i < len(self._exprs):
expr = self._exprs[i] expr = self._exprs[i]
if expr in dep_exprs:
if expr in dep_exprs or expr._disable_remove:
i += 1 i += 1
continue continue
for n in expr.inputs: for n in expr.inputs:
@@ -267,7 +522,7 @@ class InternalGraph:
def __repr__(self): def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
", ".join(str(i) for i in self._inputs), ", ".join(str(i) for i in self._inputs),
"\n\t".join(str(i) for i in self._exprs),
"\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs), ", ".join(str(i) for i in self._outputs),
) )


@@ -293,7 +548,7 @@ def _wrapped_function(orig_func):
if not NodeMixin.get(i, None): if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)): if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i)) NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
if meth_name: if meth_name:
self = inputs[0] self = inputs[0]
if meth_name == "__new__": if meth_name == "__new__":
@@ -316,10 +571,19 @@ def _wrapped_function(orig_func):
call_node.add_inputs(inputs) call_node.add_inputs(inputs)


call_node.arg_def = tree_def call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
rst = orig_func(*args, **kwargs)
if meth_name == "__setitem__":
rst = self
if rst is not None:
outputs, out_def = tree_flatten(
rst, leaf_type=_leaf_type, is_leaf=_is_leaf
)
call_node.out_def = out_def
else:
outputs = None
call_node.add_outputs(outputs) call_node.add_outputs(outputs)
set_module_tracing() set_module_tracing()
return outputs
return rst
return orig_func(*args, **kwargs) return orig_func(*args, **kwargs)


return wrapped_fn return wrapped_fn
@@ -349,6 +613,7 @@ class TracedModuleBuilder(NodeMixin):
super(TracedModuleBuilder, self).__init__() super(TracedModuleBuilder, self).__init__()
self._mod = mod self._mod = mod
self._body = None self._body = None
self._is_top = is_top_module
self._is_builtin = module_tracer.is_builtin(mod) self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {} self._argdef_graph_map = {}
self._argdef_outdef_map = {} self._argdef_outdef_map = {}
@@ -362,7 +627,7 @@ class TracedModuleBuilder(NodeMixin):
return self._mod return self._mod
else: else:
traced_module = TracedModule( traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
self._is_top, self._argdef_graph_map, self._argdef_outdef_map
) )
for _, g in self._argdef_graph_map.items(): for _, g in self._argdef_graph_map.items():
g.compile() g.compile()
@@ -408,8 +673,8 @@ class TracedModuleBuilder(NodeMixin):
self._body = None self._body = None
else: else:
self_node = None self_node = None
if self._body:
self_node = self._body.inputs[0]
if tree_def in self._argdef_graph_map:
self_node = self._argdef_graph_map[tree_def].inputs[0]
self._body = InternalGraph() self._body = InternalGraph()
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
@@ -446,7 +711,7 @@ class TracedModuleBuilder(NodeMixin):
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
): ):
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.get(self, None).actual_mnode.append(orig_self)
NodeMixin.wrap_safe(self, orig_self) NodeMixin.wrap_safe(self, orig_self)
for arg, node in zip(inputs[1:], origin_inp_node): for arg, node in zip(inputs[1:], origin_inp_node):
if node: if node:
@@ -454,6 +719,7 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().pop_scope() active_module_tracer().pop_scope()


# rebind output to outer graph # rebind output to outer graph
callnode.out_def = out_def
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
self._argdef_graph_map[callnode.arg_def] = self._body self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def self._argdef_outdef_map[callnode.arg_def] = out_def
@@ -512,31 +778,44 @@ class _expr_iter:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr yield expr
if expr.graph is not None: if expr.graph is not None:
yield from expr.graph.exprs
yield from expr.graph.expr_filter
else: else:
yield expr yield expr




class ExprFilter:
class _node_iter:
    """Iterable over the distinct nodes referenced by a graph's expressions.

    The nodes are collected once, at construction time, from the ``inputs``
    and ``outputs`` of every expression produced by ``graph.expr_filter``.
    Nodes are de-duplicated by ``_id`` (the first node seen for an id is
    kept) and stored in ascending ``_id`` order in ``self.nodes``.
    """

    def __init__(self, graph: InternalGraph) -> None:
        nodes = []
        node_ids = set()
        for expr in graph.expr_filter:
            for n in expr.inputs + expr.outputs:
                if n._id in node_ids:
                    continue
                nodes.append(n)
                node_ids.add(n._id)
        # ``sorted`` already returns a new list, so no extra ``list()`` copy.
        self.nodes = sorted(nodes, key=lambda x: x._id)

    def __iter__(self):
        # A fresh iterator per call keeps the object re-iterable.
        return iter(self.nodes)


class BaseFilter:
def __init__(self, expr_iter: Iterable): def __init__(self, expr_iter: Iterable):
self._iter = expr_iter self._iter = expr_iter


def __iter__(self): def __iter__(self):
return iter(self._iter) return iter(self._iter)


def call_function(self, func):
return ExprFilterCallFunction(self, func)

def call_method(self, method):
return ExprFilterCallMethod(self, method)

def as_list(self): def as_list(self):
return list(self) return list(self)


def as_dict(self): def as_dict(self):
raise NotImplementedError("need key")
return collections.OrderedDict((i._id, i) for i in self)


def as_unique(self):
    """Return the single element produced by this filter.

    The filter is materialised exactly once (through ``as_list``) and the
    element is taken from that list, so a filter backed by a one-shot
    iterator is not consumed a second time.

    Raises:
        AssertionError: if the filter does not yield exactly one element.
    """
    rst = self.as_list()
    assert len(rst) == 1, "{} elements found".format(len(rst))
    # Unpack from the list already built instead of iterating ``self`` again.
    (expr,) = rst
    return expr


@@ -544,17 +823,65 @@ class ExprFilter:
return sum(1 for _ in self) return sum(1 for _ in self)




class ExprFilter(BaseFilter):
    """Chainable filter over expressions.

    Each method wraps this filter in a more specific ``ExprFilter`` subclass;
    nothing is iterated until the returned filter itself is iterated.
    """

    def call_function(self, func):
        """Return a filter restricted by ``ExprFilterCallFunction`` on ``func``."""
        narrowed = ExprFilterCallFunction(self, func)
        return narrowed

    def call_method(self, method):
        """Return a filter restricted by ``ExprFilterCallMethod`` on ``method``."""
        narrowed = ExprFilterCallMethod(self, method)
        return narrowed

    def expr_id(self, expr_id: List[int]):
        """Return a filter restricted by ``ExprFilterExprId`` on ``expr_id``."""
        narrowed = ExprFilterExprId(self, expr_id)
        return narrowed


class NodeFilter(BaseFilter):
    """Chainable filter over nodes.

    Each method wraps this filter in a more specific ``NodeFilter`` subclass;
    nothing is iterated until the returned filter itself is iterated.
    """

    def type(self, owner_type, node_type):
        """Return a ``NodeFilterType`` built from this filter and both types."""
        narrowed = NodeFilterType(self, owner_type, node_type)
        return narrowed

    def node_id(self, node_id: List[int]):
        """Return a ``NodeFilterNodeId`` built from this filter and ``node_id``."""
        narrowed = NodeFilterNodeId(self, node_id)
        return narrowed


class NodeFilterType(NodeFilter):
    """Node filter keeping nodes of ``node_type`` whose owner is ``owner_type``."""

    def __init__(self, expr_iter, owner_type, node_type):
        super().__init__(expr_iter)
        self.owner_type, self.node_type = owner_type, node_type

    def __iter__(self):
        """Yield nodes that pass the node-type, has-owner and owner-type checks."""
        for candidate in self._iter:
            # Checks run left to right, so ``owner`` is only read from nodes
            # of the requested type that actually expose the attribute.
            if (
                isinstance(candidate, self.node_type)
                and hasattr(candidate, "owner")
                and isinstance(candidate.owner, self.owner_type)
            ):
                yield candidate


class NodeFilterNodeId(NodeFilter):
    """Node filter keeping only nodes whose ``_id`` is among the given ids."""

    def __init__(self, expr_iter, node_id: List[int]):
        super().__init__(expr_iter)
        # A lone (non-sequence) id is promoted to a one-element list.
        self.node_id = node_id if isinstance(node_id, Sequence) else [node_id]

    def __iter__(self):
        """Yield, in upstream order, the nodes whose ``_id`` was requested."""
        yield from (
            candidate for candidate in self._iter if candidate._id in self.node_id
        )


class ExprFilterCallFunction(ExprFilter):
    """Expression filter keeping ``CallFunction`` expressions.

    With ``func`` left as ``None`` every ``CallFunction`` passes; otherwise
    only those whose ``func`` compares equal to the requested one.
    """

    def __init__(self, expr_iter, func: Callable = None):
        super().__init__(expr_iter)
        self.func = func

    def __iter__(self):
        """Yield the matching ``CallFunction`` expressions in upstream order."""
        for expr in self._iter:
            if isinstance(expr, CallFunction) and (
                self.func is None or expr.func == self.func
            ):
                yield expr




class ExprFilterCallMethod(ExprFilter): class ExprFilterCallMethod(ExprFilter):
@@ -563,11 +890,24 @@ class ExprFilterCallMethod(ExprFilter):
self.method = method self.method = method


def __iter__(self): def __iter__(self):
for i in self._iter:
if not isinstance(i, CallMethod):
for expr in self._iter:
if not isinstance(expr, CallMethod):
continue continue
if self.method is None or i.method == self.method:
yield i
if self.method is None or expr.method == self.method:
yield expr


class ExprFilterExprId(ExprFilter):
    """Expression filter keeping only expressions whose ``_id`` was requested."""

    def __init__(self, expr_iter, expr_id: List[int]):
        super().__init__(expr_iter)
        # A lone (non-sequence) id is promoted to a one-element list.
        self.expr_id = expr_id if isinstance(expr_id, Sequence) else [expr_id]

    def __iter__(self):
        """Yield, in upstream order, the expressions whose ``_id`` was requested."""
        yield from (
            candidate for candidate in self._iter if candidate._id in self.expr_id
        )




class TracedModule(Module): class TracedModule(Module):
@@ -579,10 +919,11 @@ class TracedModule(Module):
argdef_graph_map = None argdef_graph_map = None
argdef_outdef_map = None argdef_outdef_map = None


def __init__(self, argdef_graph_map, argdef_outdef_map):
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
super(TracedModule, self).__init__() super(TracedModule, self).__init__()
self.argdef_graph_map = argdef_graph_map self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map self.argdef_outdef_map = argdef_outdef_map
self._is_top = is_top


def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten( inputs, treedef = tree_flatten(
@@ -598,29 +939,58 @@ class TracedModule(Module):
return outputs return outputs


@property @property
def graph(self):
self._update_modulenode_ref()
def graph(self) -> InternalGraph:
if self._is_top:
self._update_ref()
assert len(self.argdef_graph_map) == 1 assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0] return list(self.argdef_graph_map.values())[0]


def _update_modulenode_ref(self):
for _, graph in self.argdef_graph_map.items():
def _update_ref(self, actual_node_map: Union[Dict] = None):
for inp_def, graph in self.argdef_graph_map.items():
for n in graph._inputs + graph.outputs:
n._top_graph = weakref.ref(graph)
graph._inputs[0]._owner = weakref.ref(self) graph._inputs[0]._owner = weakref.ref(self)
graph._inputs[0].actual_mnode = []
if actual_node_map is not None and inp_def in actual_node_map.keys():
graph._inputs[0].actual_mnode = actual_node_map[inp_def]
node2obj = {} node2obj = {}
next_actual_node_map = collections.defaultdict(
lambda: collections.defaultdict(list)
)
node2obj[graph._inputs[0]] = self node2obj[graph._inputs[0]] = self
for expr in graph._exprs: for expr in graph._exprs:
for n in expr.inputs + expr.outputs:
n._top_graph = weakref.ref(graph)
expr._top_graph = weakref.ref(graph)
if isinstance(expr, GetAttr) and isinstance( if isinstance(expr, GetAttr) and isinstance(
expr.outputs[0], ModuleNode expr.outputs[0], ModuleNode
): ):
obj = getattr(node2obj[expr.inputs[0]], expr.name) obj = getattr(node2obj[expr.inputs[0]], expr.name)
expr.outputs[0]._owner = weakref.ref(obj) expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj node2obj[expr.outputs[0]] = obj
if isinstance(obj, TracedModule):
obj._update_modulenode_ref()
if isinstance(expr, Constant) and isinstance(
expr.outputs[0], ModuleNode
):
obj = expr.value
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
if (
isinstance(expr, CallMethod)
and expr.method == "__call__"
and isinstance(expr.inputs[0], ModuleNode)
):
obj = node2obj[expr.inputs[0]]
if expr.arg_def is not None:
next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])


@property
def exprs(self):
return self.graph.exprs
for obj in node2obj.values():
if obj is self:
continue
mnode_map = None
if obj in next_actual_node_map.keys():
mnode_map = next_actual_node_map[obj]
if isinstance(obj, TracedModule):
obj._update_ref(mnode_map)


def flatten(self): def flatten(self):
""" """
@@ -644,13 +1014,21 @@ class TracedModule(Module):
node2obj[graph._inputs[0]] = module node2obj[graph._inputs[0]] = module
if call: if call:
node2obj[call.inputs[0]] = module node2obj[call.inputs[0]] = module
repl_dict = dict(zip(graph._inputs, call.inputs))
for ind, out in enumerate(graph.outputs):
if isinstance(out.expr, Input):
assert out in repl_dict
call_out = call.outputs[ind]
for expr in call.outputs[ind].users:
for index, inp in enumerate(expr.inputs):
if inp is call_out:
expr.inputs[index] = repl_dict[out]

continue
repl_dict[out] = call.outputs[ind]

graph._replace_inputs_outputs(repl_dict)
for expr in graph._exprs: for expr in graph._exprs:
# replace inputs for submodule's exprx
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)


if isinstance(expr, GetAttr): if isinstance(expr, GetAttr):
# replace GetAttr with Constant # replace GetAttr with Constant
@@ -715,6 +1093,21 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls) module_tracer.register_as_builtin(mod_cls)




def wrap(func: Union[Callable]):
    """Record a module-level Python function in ``Patcher._builtin_functions``.

    The pair ``(caller's globals, function name)`` is appended to
    ``Patcher._builtin_functions`` and ``func`` is returned unchanged, so
    ``wrap`` can be used either as a plain call or as a decorator.

    Args:
        func: a callable that exposes ``__code__`` (i.e. a Python function);
            its ``__code__.co_name`` is the name that gets recorded.

    Returns:
        ``func`` itself.

    Raises:
        TypeError: if ``func`` has no ``__code__`` (e.g. a builtin or a
            ``functools.partial``), so no function name can be derived.
        NotImplementedError: if ``wrap`` is not called from module top level.
    """
    assert callable(func)
    code = getattr(func, "__code__", None)
    if code is None:
        # Previously this path fell through and failed later with an
        # UnboundLocalError on ``fn_name``; fail early with a clear message.
        raise TypeError(
            "wrap expects a Python function with a __code__ attribute, "
            "got {!r}".format(func)
        )
    fn_name = code.co_name
    currentframe = inspect.currentframe()
    assert currentframe is not None
    # The caller's frame supplies the globals in which ``fn_name`` is bound.
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")
    Patcher._builtin_functions.append((f.f_globals, fn_name))
    return func


def _register_all_builtin_module(): def _register_all_builtin_module():


for sub_mod in [M, M.qat, M.quantized]: for sub_mod in [M, M.qat, M.quantized]:
@@ -749,6 +1142,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf) inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
assert isinstance(i, Tensor), "not support "
if isinstance(i, RawTensor): if isinstance(i, RawTensor):
NodeMixin.wrap_safe( NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))


+ 5
- 5
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -57,16 +57,16 @@ def _init_module():
def test_search(): def test_search():
traced_module, *_ = _init_block() traced_module, *_ = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu




def test_insert(): def test_insert():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_node = graph.get_call_function(F.relu).as_unique().outputs
neg_node = graph.insert_call_function(F.neg, relu_node)
graph.replace_node({relu_node[0]: neg_node[0]})
relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
graph.replace_node({relu_node[0]: neg_node})
graph.compile() graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)


@@ -74,7 +74,7 @@ def test_insert():
def test_delete(): def test_delete():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
relu_expr = graph.get_function_by_type(F.relu).as_unique()
node = relu_expr.outputs node = relu_expr.outputs
repl_node = relu_expr.inputs repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]}) graph.replace_node({node[0]: repl_node[0]})


Loading…
Cancel
Save