
feat(traced_module): add functions for graph modification

GitOrigin-RevId: 09691ebd33
Branch: release-1.6
Megvii Engine Team, 3 years ago
Parent commit: c7e730bc12
6 changed files with 612 additions and 124 deletions
  1. +2 -0    imperative/python/megengine/experimental/traced_module/__init__.py
  2. +68 -13  imperative/python/megengine/experimental/traced_module/expr.py
  3. +22 -8   imperative/python/megengine/experimental/traced_module/node.py
  4. +23 -0   imperative/python/megengine/experimental/traced_module/pytree.py
  5. +492 -98 imperative/python/megengine/experimental/traced_module/traced_module.py
  6. +5 -5    imperative/python/test/unit/traced_module/test_modification.py

+2 -0  imperative/python/megengine/experimental/traced_module/__init__.py

@@ -13,6 +13,8 @@ from .traced_module import (
cpp_apply_module_trace,
register_as_builtin,
trace_module,
wrap,
wrap_tensors,
)

_register_all_builtin_module()


+68 -13  imperative/python/megengine/experimental/traced_module/expr.py

@@ -11,7 +11,7 @@ import builtins
import collections
import copy
import inspect
from typing import Callable, List
from typing import Callable, Dict, List

from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -29,10 +29,24 @@ class Expr:
``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, ``GetAttr``, ``Input``, ``Constant``) on ``Node``s.
"""

__total_id = 0
inputs = None # type: List[Node]
outputs = None # type: List[Node]
const_val = None # type: List[Any]
arg_def = None # type: TreeDef
out_def = None # type: TreeDef
_top_graph = None # type: weakref.ReferenceType

def __init__(self) -> None:
self._id = Expr.__total_id
Expr.__total_id += 1
self._disable_remove = False

def enable_remove(self):
self._disable_remove = False

def disable_remove(self):
self._disable_remove = True

def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence):
@@ -70,6 +84,22 @@ class Expr:
else:
return inputs, {}

def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
repl_node.users.append(self)
node.users.remove(self)

def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)

def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)

@property
def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs)
@@ -80,12 +110,19 @@ class Expr:
args, _ = self.unflatten_args(self.inputs)
return args

@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None


# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
name = None

def __init__(self, name=None, type=None):
super().__init__()
self.inputs = []
node_cls = type if type else Node
self.outputs = [
@@ -100,7 +137,7 @@ class Input(Expr):
return expr.outputs[0]

def __repr__(self):
return "{} = Input({})".format(self.outputs[0], self.name)
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)


# expr: outputs = getattr(inputs[0], self.name)
@@ -108,6 +145,7 @@ class GetAttr(Expr):
name = None

def __init__(self, module, name, type=None):
super().__init__()
assert isinstance(module, ModuleNode)
self.inputs = [
module,
@@ -130,14 +168,15 @@ class GetAttr(Expr):
return (getattr(inputs[0], self.name),)

def __repr__(self):
return '{} = GetAttr({}, "{}")'.format(
self.outputs[0], self.inputs[0], self.name
return '%{}: {} = GetAttr({}, "{}")'.format(
self._id, self.outputs[0], self.inputs[0], self.name
)


# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
def __init__(self, node, method="__call__"):
super().__init__()
if isinstance(node, type):
assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor
@@ -178,6 +217,8 @@ class CallMethod(Expr):
if inspect.ismethod(meth):
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if self.method == "__setitem__":
outputs = obj
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
@@ -186,8 +227,12 @@ class CallMethod(Expr):
def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.args[0],
self.method,
", ".join([args, kwargs]),
@@ -199,6 +244,7 @@ class Apply(Expr):
opdef = None

def __init__(self, opdef):
super().__init__()
assert isinstance(opdef, OpDef)
self.opdef = opdef
self.inputs = []
@@ -213,7 +259,8 @@ class Apply(Expr):
return apply(self.opdef, *inputs)

def __repr__(self):
return "{} = {}({})".format(
return "%{}: {} = {}({})".format(
self._id,
", ".join(str(i) for i in self.outputs),
self.opdef,
", ".join(str(i) for i in self.inputs),
@@ -241,6 +288,7 @@ class Apply(Expr):

class CallFunction(Expr):
def __init__(self, func):
super().__init__()
assert isinstance(func, Callable)
self.func = func
self.const_val = []
@@ -255,16 +303,20 @@ class CallFunction(Expr):
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs

def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__,
", ".join([args, kwargs]),
)
@@ -277,6 +329,7 @@ class Constant(Expr):
_constant_cache = {}

def __init__(self, c):
super().__init__()
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
@@ -299,7 +352,9 @@ class Constant(Expr):
return (self.value,)

def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], type(self.value))
return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
)

def __getstate__(self):
state = self.__dict__.copy()
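
The new Expr machinery above can be exercised roughly as follows (a minimal sketch; `graph`, `old_node`, and `new_node` are hypothetical names, and the printed forms are illustrative):

    expr = graph.get_expr_by_id(4).as_unique()  # Exprs now carry a global id and print as "%4: ..."
    assert expr.top_graph is graph              # weak reference back to the owning InternalGraph
    expr.replace_inputs({old_node: new_node})   # rewires inputs and updates both nodes' users lists
    expr.disable_remove()                       # protects this Expr from removal in compile()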


+22 -8  imperative/python/megengine/experimental/traced_module/node.py

@@ -30,6 +30,7 @@ class Node:
__total_id = 0
_id = None
_name = None
_top_graph = None # type: weakref.ReferenceType

def __init__(self, expr: "Expr", name: str = None):
self.expr = expr
@@ -48,6 +49,12 @@ class Node:
else:
return "%{}".format(self._name)

@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None


class ModuleNode(Node):
"""
@@ -64,21 +71,28 @@ class ModuleNode(Node):

def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.actual_mnode = []

def __repr__(self):
if self._name is None:
return "%{}({})".format(self._id, self.module_type.__name__)
return "%{}_({})".format(self._id, self.module_type.__name__)
else:
return "%{}({})".format(self._name, self.module_type.__name__)
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)

def __getstate__(self):
d = self.__dict__
d.pop("_owner", None)
return d
return {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"_name": self._name,
"module_type": self.module_type,
}

@property
def owner(self):
return self._owner()
if self._owner:
return self._owner()
return None


class TensorNode(Node):
@@ -91,9 +105,9 @@ class TensorNode(Node):

def __repr__(self):
if self._name is None:
return "%{}(Tensor)".format(self._id)
return "%{}_(Tensor)".format(self._id)
else:
return "%{}(Tensor)".format(self._name)
return "%{}_{}(Tensor)".format(self._id, self._name)


class NodeMixin(abc.ABC):
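
With the id-prefixed reprs and the `_top_graph` back reference, nodes can now be located and inspected like this (sketch; the id and name are made up):

    node = graph.get_node_by_id(11).as_unique()  # accepts a single id or a list of ids
    print(node)                                  # e.g. "%11_fc(Tensor)", or "%11_(Tensor)" if unnamed
    assert node.top_graph is graph               # None if the node has no owning graph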


+23 -0  imperative/python/megengine/experimental/traced_module/pytree.py

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import collections
from collections import OrderedDict
from typing import Callable, NamedTuple

import numpy as np
@@ -34,10 +35,26 @@ def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps))


def _ordereddict_flatten(inp):
aux_data = []
results = []
for key, value in inp.items():
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)


def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
)
register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),
lambda x, aux_data: slice(x[0], x[1], x[2]),
@@ -99,6 +116,12 @@ class TreeDef:
)
)

def __lt__(self, other):
return self.__hash__() < other.__hash__()

def __gt__(self, other):
return self.__hash__() > other.__hash__()

def __eq__(self, other):
return (
self.type == other.type
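
A round trip through the new OrderedDict support (sketch, assuming this module's `tree_flatten` and hypothetical leaves `w`/`b`):

    from collections import OrderedDict
    values, treedef = tree_flatten(OrderedDict([("w", w), ("b", b)]))
    restored = treedef.unflatten(values)  # an OrderedDict with key order preserved

The new `__lt__`/`__gt__` compare hashes, so `TreeDef`s can be sorted deterministically, e.g. when they are used as dict keys in `argdef_graph_map`.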


+492 -98  imperative/python/megengine/experimental/traced_module/traced_module.py

@@ -9,12 +9,10 @@
import collections
import copy
import functools
import inspect
import weakref
from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type

import numpy as np
from numpy.lib.arraysetops import isin
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from ... import functional as F
from ... import get_logger
@@ -43,9 +41,9 @@ logger = get_logger(__name__)


def _leaf_type(node):
if isinstance(node, RawTensor):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode)
elif isinstance(node, (NodeMixin, Module)):
elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin)
else:
return type(node)
@@ -64,6 +62,50 @@ def _is_const_leaf(node):
return True


def wrap_tensors(tensors: Tensor, nodes: TensorNode):
inp_tensors = copy.deepcopy(tensors)
inp_tensors, inp_def_v = tree_flatten(
inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inp_nodes, inp_def_n = tree_flatten(
nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
for v, n in zip(inp_tensors, inp_nodes):
if isinstance(n, TensorNode) and isinstance(v, Tensor):
NodeMixin.wrap_safe(v, n)
return inp_def_v.unflatten(inp_tensors)
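
`wrap_tensors` deep-copies the given tensors and binds each copy to its matching `TensorNode`, returning the copies in the original pytree structure; a one-line sketch with hypothetical `x`/`x_node`:

    traced_x, = wrap_tensors((x,), (x_node,))  # traced_x carries x_node during tracing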


class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
self.graph = graph
self.global_scope = InternalGraph()
self.expr = expr
self.after = after

def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
assert active_module_tracer() is None
set_active_module_tracer(module_tracer(_wrapped_function))
active_module_tracer().patcher.__enter__()
active_module_tracer().push_scope(self.global_scope)

def __exit__(self, ty, va, tr):
set_symbolic_shape(self.use_sym_shape)
unset_module_tracing()
active_module_tracer().patcher.__exit__(ty, va, tr)
set_active_module_tracer(None)
index = len(self.graph._exprs) if self.after else 0
if self.expr is not None:
index = self.graph._exprs.index(self.expr)
if self.after:
index += 1
for expr in self.global_scope._exprs:
self.graph._exprs.insert(index, expr)
index += 1


class InternalGraph:
"""
``InternalGraph`` is a graph consisting of ``Node``s and ``Expr``s; it represents the execution procedure of a Module's ``forward`` method.
@@ -95,14 +137,28 @@ class InternalGraph:
return self._outputs

@property
def exprs(self):
def expr_filter(self):
return ExprFilter(_expr_iter(self))

def get_call_function(self, func: Callable = None):
return self.exprs.call_function(func)
@property
def node_filter(self):
return NodeFilter(_node_iter(self))

def get_function_by_type(self, func: Callable = None):
return self.expr_filter.call_function(func)

def get_method_by_type(self, method: str = None):
return self.expr_filter.call_method(method)

def get_call_method(self, method: str = None):
return self.exprs.call_method(method)
def get_expr_by_id(self, expr_id: List[int] = None):
return self.expr_filter.expr_id(expr_id)

def get_module_by_type(self, module_cls: Module):
assert issubclass(module_cls, Module)
return self.node_filter.type(module_cls, ModuleNode)

def get_node_by_id(self, node_id: List[int] = None):
return self.node_filter.node_id(node_id)

def add_input(self, i):
self._inputs.append(i)
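
The renamed lookup helpers read like this (sketch; assumes `graph` comes from a traced module, with the usual imports `megengine.functional as F` and `megengine.module as M`):

    relu_exprs = graph.get_function_by_type(F.relu).as_list()    # CallFunction exprs calling F.relu
    calls      = graph.get_method_by_type("__call__").as_list()  # CallMethod exprs
    convs      = graph.get_module_by_type(M.Conv2d).as_list()    # ModuleNodes whose owner is a Conv2d
    one_expr   = graph.get_expr_by_id(4).as_unique()             # a single Expr by its global id
    two_nodes  = graph.get_node_by_id([1, 2]).as_list()          # Nodes by global id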
@@ -124,7 +180,6 @@ class InternalGraph:
for idx, o in enumerate(self._outputs):
if o in repl_dict:
self._outputs[idx] = repl_dict[o]
self._outputs[idx].expr = node.expr

for expr in self._exprs:

@@ -135,83 +190,283 @@ class InternalGraph:
for idx, o in enumerate(expr.outputs):
if o in repl_dict:
expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr

def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
ret = list()
queue = list(nodes)
visited_queue = list()
while queue:
node = queue.pop()
visited_queue.append(node)

expr = node.expr

if expr not in ret:
ret.append(expr)

for i in expr.inputs:
if i not in queue:
if i not in queue and i not in visited_queue:
queue.append(i)
return ret

def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
if not isinstance(nodes, Sequence):
nodes = [nodes]
assert isinstance(func, Callable)
for i in nodes:
assert isinstance(
i, TensorNode
), "CallFunction only accept TensorNode as inputs"
def reset_inputs(self, *args, **kwargs):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode
call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append((c_expr, n))

expr = CallFunction(func)
expr.inputs = nodes
moudle = forma_mnode.owner
assert moudle._is_top, "reset_inputs only support the top-level graph"

inputs, tree_def = tree_flatten(
((moudle, *args), kwargs),
leaf_type=_leaf_type,
is_const_leaf=_is_const_leaf,
)

for i in nodes:
i.users.append(expr)
def create_node(val: Tensor):
node = Input(type=TensorNode).outputs[0]
node.shape = val.shape
node.dtype = val.dtype
return node

idx = max(self._exprs.index(i.expr) for i in nodes) + 1
self._exprs.insert(idx, expr)
formal_node_inputs = [
forma_mnode,
]

org_argdef = list(moudle.argdef_graph_map.keys())[0]
if call_nodes:
org_argdef = call_nodes[0][0].arg_def

for v in inputs[1:]:
assert isinstance(v, RawTensor)
formal_node_inputs.append(create_node(v))

actual_nodes = []
for e, n in call_nodes:
e.arg_def = tree_def
actual_node_inputs = [
n,
]
for v in inputs[1:]:
actual_node_inputs.append(create_node(v))

for org_n in e.inputs:
org_n.users.remove(e)

e.inputs[:] = actual_node_inputs
e.const_val = []
actual_nodes.append(actual_node_inputs[1:])

self._inputs[:] = formal_node_inputs
moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)

# return formal_node_inputs[1:], actual_nodes
return formal_node_inputs[1:]
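
Sketch of use on a top-level traced module (`traced` and the input shape are hypothetical); the example tensors define the new signature, and the new formal input nodes are returned:

    import megengine.functional as F
    new_inp_nodes = traced.graph.reset_inputs(F.zeros((1, 3, 32, 32)))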

def add_input_node(self, shape, dtype="float32"):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode

moudle = forma_mnode.owner
assert moudle._is_top, "add_input_node only support the top-level graph"

call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append(c_expr)

def create_node(is_input: bool = True):
if is_input:
node = Input(type=TensorNode).outputs[0]
else:
node = TensorNode(expr=None)
node.shape = shape
node.dtype = dtype
return node

fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
fake_out_val = func(*fake_inp_val)
org_argdef = list(moudle.argdef_graph_map.keys())[0]

def create_node(val: Tensor):
if call_nodes:
org_argdef = call_nodes[0].arg_def

args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(True)
inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
self._inputs[:] = inputs[:]

actual_inp_nodes = []
for e in call_nodes:
args, kwargs = e.unflatten_args(e.inputs)
args = args + (create_node(False),)
inputs, tree_def = tree_flatten(
(args, kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
e.inputs[:] = inputs[:]
e.arg_def = tree_def
actual_inp_nodes.append(args[-1])

moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)

# return formal_inp_node, actual_inp_nodes
return formal_inp_node
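
Sketch (hypothetical shape): appends one extra formal input to the top-level graph and to every actual call site, returning the new formal node:

    extra_inp = traced.graph.add_input_node(shape=(1, 16), dtype="float32")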

def reset_outputs(self, outputs):
outputs, out_def = tree_flatten(
outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
)
forma_mnode = self.inputs[0]

moudle = forma_mnode.owner
assert moudle._is_top, "reset_outputs only support the top-level graph"

actual_mnodes = forma_mnode.actual_mnode
call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append((c_expr))

def create_node(val: TensorNode, expr: Expr):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node

out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
tree_def = list(moudle.argdef_graph_map.keys())[0]
if call_nodes:
tree_def = call_nodes[0].arg_def

return out_nodes
actual_nodes = []
for e in call_nodes:
actual_node_outputs = []
for v in outputs:
actual_node_outputs.append(create_node(v, e))
e.outputs[:] = actual_node_outputs
e.out_def = out_def
actual_nodes.append(actual_node_outputs)

def insert_call_method(self, target, method, args):
if not isinstance(args, Sequence):
args = [args]
assert isinstance(target, (TensorNode, ModuleNode))
assert isinstance(method, str)
for i in args:
assert isinstance(i, TensorNode)
self._outputs[:] = outputs
moudle.argdef_outdef_map[tree_def] = out_def

expr = CallMethod(method)
expr.inputs = [target, *args]
return actual_nodes
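
Sketch (hypothetical node id): any pytree of `TensorNode`s becomes the new output definition of the top-level graph:

    feat = traced.graph.get_node_by_id(12).as_unique()
    traced.graph.reset_outputs({"feature": feat})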

if isinstance(target, TensorNode):
fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
def add_output_node(self, node: TensorNode):
forma_mnode = self.inputs[0]

def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node
moudle = forma_mnode.owner
assert moudle._is_top, "add_output_node only support the top-level graph"

out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
else:
raise NotImplementedError()
actual_mnodes = forma_mnode.actual_mnode
call_nodes = []

for n in actual_mnodes:
for c_expr in n.users:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append((c_expr))

def create_node(val: TensorNode, expr: Expr):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node

tree_def = list(moudle.argdef_graph_map.keys())[0]
if call_nodes:
tree_def = call_nodes[0].arg_def

org_out_def = moudle.argdef_outdef_map[tree_def]
org_outs = org_out_def.unflatten(self._outputs)
outputs, out_def = tree_flatten(
(org_outs, node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
)
self._outputs[:] = outputs

actual_out_nodes = []
for e in call_nodes:
actual_node = create_node(node, e)
org_outs = org_out_def.unflatten(e.outputs)
outputs, out_def = tree_flatten(
(org_outs, actual_node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
)
e.outputs[:] = outputs
e.out_def = out_def
actual_out_nodes.append(actual_node)

moudle.argdef_outdef_map[tree_def] = out_def

return actual_out_nodes
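
Sketch (hypothetical node id): exposes an intermediate node as an additional graph output and returns the actual output nodes created at each call site:

    mid = traced.graph.get_node_by_id(7).as_unique()
    actual_outs = traced.graph.add_output_node(mid)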

def insert_function(self, func: Callable, *args, **kwargs):
assert isinstance(func, Callable)

inp_nodes, inp_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)

insert_idx = -1
for i in inp_nodes:
if isinstance(i, TensorNode) and i.expr in self._exprs:
insert_idx = max(insert_idx, self._exprs.index(i.expr))

fake_inp_val = list(
F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
for i in inp_nodes
)

for v, n in zip(fake_inp_val, inp_nodes):
if isinstance(n, TensorNode):
NodeMixin.wrap_safe(v, n)

fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)

insert_point = self.insert_exprs_before()
if insert_idx != -1:
insert_point = self.insert_exprs_after(self._exprs[insert_idx])

with insert_point:
rst = func(*fake_args, **fake_kwargs)

if rst is None:
return None

outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
node_outputs = []
for out in outputs:
assert isinstance(out, RawTensor)
node_outputs.append(NodeMixin.get(out, None))

node_outputs = out_def.unflatten(node_outputs)
return node_outputs
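
This is the high-level entry point exercised by test_modification.py below: it traces `func` on fake tensors built from the input nodes' shapes and splices the recorded Exprs in at the right position:

    relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
    neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
    graph.replace_node({relu_node[0]: neg_node})
    graph.compile()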

def insert_exprs_after(self, expr: Optional[Expr] = None):
if expr is not None:
assert expr.top_graph == self, "Expr to insert after is not in graph."
return _InsertExprs(self, expr, after=True)

return out_nodes
def insert_exprs_before(self, expr: Optional[Expr] = None):
if expr is not None:
assert expr.top_graph == self, "Expr to insert before is not in graph."
return _InsertExprs(self, expr, after=False)
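
`insert_function` is built on these two context managers; used directly, they re-enable module tracing so every megengine call inside the `with` block is recorded into the graph at the chosen position (sketch; `some_expr` is hypothetical, and input tensors must already be bound to their nodes, e.g. via `wrap_tensors`):

    with graph.insert_exprs_after(some_expr):
        y = F.relu(x)  # recorded as new Exprs right after some_expr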

def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
@@ -246,7 +501,7 @@ class InternalGraph:
i = 0
while i < len(self._exprs):
expr = self._exprs[i]
if expr in dep_exprs:
if expr in dep_exprs or expr._disable_remove:
i += 1
continue
for n in expr.inputs:
@@ -267,7 +522,7 @@ class InternalGraph:
def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
", ".join(str(i) for i in self._inputs),
"\n\t".join(str(i) for i in self._exprs),
"\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs),
)

@@ -293,7 +548,7 @@ def _wrapped_function(orig_func):
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
if meth_name:
self = inputs[0]
if meth_name == "__new__":
@@ -316,10 +571,19 @@ def _wrapped_function(orig_func):
call_node.add_inputs(inputs)

call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
rst = orig_func(*args, **kwargs)
if meth_name == "__setitem__":
rst = self
if rst is not None:
outputs, out_def = tree_flatten(
rst, leaf_type=_leaf_type, is_leaf=_is_leaf
)
call_node.out_def = out_def
else:
outputs = None
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return rst
return orig_func(*args, **kwargs)

return wrapped_fn
@@ -349,6 +613,7 @@ class TracedModuleBuilder(NodeMixin):
super(TracedModuleBuilder, self).__init__()
self._mod = mod
self._body = None
self._is_top = is_top_module
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
@@ -362,7 +627,7 @@ class TracedModuleBuilder(NodeMixin):
return self._mod
else:
traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
self._is_top, self._argdef_graph_map, self._argdef_outdef_map
)
for _, g in self._argdef_graph_map.items():
g.compile()
@@ -408,8 +673,8 @@ class TracedModuleBuilder(NodeMixin):
self._body = None
else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
if tree_def in self._argdef_graph_map:
self_node = self._argdef_graph_map[tree_def].inputs[0]
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
@@ -446,7 +711,7 @@ class TracedModuleBuilder(NodeMixin):
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.get(self, None).actual_mnode.append(orig_self)
NodeMixin.wrap_safe(self, orig_self)
for arg, node in zip(inputs[1:], origin_inp_node):
if node:
@@ -454,6 +719,7 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().pop_scope()

# rebind output to outer graph
callnode.out_def = out_def
callnode.add_outputs(outputs)
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
@@ -512,31 +778,44 @@ class _expr_iter:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
if expr.graph is not None:
yield from expr.graph.exprs
yield from expr.graph.expr_filter
else:
yield expr


class ExprFilter:
class _node_iter:
def __init__(self, graph: InternalGraph) -> None:
nodes = []
node_ids = set()
for expr in graph.expr_filter:
for n in expr.inputs + expr.outputs:
if n._id in node_ids:
continue
nodes.append(n)
node_ids.add(n._id)
self.nodes = list(sorted(nodes, key=lambda x: x._id))

def __iter__(self):
for node in self.nodes:
yield node


class BaseFilter:
def __init__(self, expr_iter: Iterable):
self._iter = expr_iter

def __iter__(self):
return iter(self._iter)

def call_function(self, func):
return ExprFilterCallFunction(self, func)

def call_method(self, method):
return ExprFilterCallMethod(self, method)

def as_list(self):
return list(self)

def as_dict(self):
raise NotImplementedError("need key")
return collections.OrderedDict((i._id, i) for i in self)

def as_unique(self):
rst = self.as_list()
assert len(rst) == 1, "{} elements found".format(len(rst))
(expr,) = self
return expr

@@ -544,17 +823,65 @@ class ExprFilter:
return sum(1 for _ in self)


class ExprFilter(BaseFilter):
def call_function(self, func):
return ExprFilterCallFunction(self, func)

def call_method(self, method):
return ExprFilterCallMethod(self, method)

def expr_id(self, expr_id: List[int]):
return ExprFilterExprId(self, expr_id)


class NodeFilter(BaseFilter):
def type(self, owner_type, node_type):
return NodeFilterType(self, owner_type, node_type)

def node_id(self, node_id: List[int]):
return NodeFilterNodeId(self, node_id)


class NodeFilterType(NodeFilter):
def __init__(self, expr_iter, owner_type, node_type):
super().__init__(expr_iter)
self.owner_type = owner_type
self.node_type = node_type

def __iter__(self):
for node in self._iter:
if not isinstance(node, self.node_type):
continue
if not hasattr(node, "owner"):
continue
if isinstance(node.owner, self.owner_type):
yield node


class NodeFilterNodeId(NodeFilter):
def __init__(self, expr_iter, node_id: List[int]):
super().__init__(expr_iter)
if not isinstance(node_id, Sequence):
node_id = [node_id]
self.node_id = node_id

def __iter__(self):
for node in self._iter:
if node._id in self.node_id:
yield node


class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter)
self.func = func

def __iter__(self):
for i in self._iter:
if not isinstance(i, CallFunction):
for expr in self._iter:
if not isinstance(expr, CallFunction):
continue
if self.func is None or i.func == self.func:
yield i
if self.func is None or expr.func == self.func:
yield expr


class ExprFilterCallMethod(ExprFilter):
@@ -563,11 +890,24 @@ class ExprFilterCallMethod(ExprFilter):
self.method = method

def __iter__(self):
for i in self._iter:
if not isinstance(i, CallMethod):
for expr in self._iter:
if not isinstance(expr, CallMethod):
continue
if self.method is None or i.method == self.method:
yield i
if self.method is None or expr.method == self.method:
yield expr


class ExprFilterExprId(ExprFilter):
def __init__(self, expr_iter, expr_id: List[int]):
super().__init__(expr_iter)
if not isinstance(expr_id, Sequence):
expr_id = [expr_id]
self.expr_id = expr_id

def __iter__(self):
for expr in self._iter:
if expr._id in self.expr_id:
yield expr


class TracedModule(Module):
@@ -579,10 +919,11 @@ class TracedModule(Module):
argdef_graph_map = None
argdef_outdef_map = None

def __init__(self, argdef_graph_map, argdef_outdef_map):
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
super(TracedModule, self).__init__()
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map
self._is_top = is_top

def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(
@@ -598,29 +939,58 @@ class TracedModule(Module):
return outputs

@property
def graph(self):
self._update_modulenode_ref()
def graph(self) -> InternalGraph:
if self._is_top:
self._update_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]

def _update_modulenode_ref(self):
for _, graph in self.argdef_graph_map.items():
def _update_ref(self, actual_node_map: Union[Dict] = None):
for inp_def, graph in self.argdef_graph_map.items():
for n in graph._inputs + graph.outputs:
n._top_graph = weakref.ref(graph)
graph._inputs[0]._owner = weakref.ref(self)
graph._inputs[0].actual_mnode = []
if actual_node_map is not None and inp_def in actual_node_map.keys():
graph._inputs[0].actual_mnode = actual_node_map[inp_def]
node2obj = {}
next_actual_node_map = collections.defaultdict(
lambda: collections.defaultdict(list)
)
node2obj[graph._inputs[0]] = self
for expr in graph._exprs:
for n in expr.inputs + expr.outputs:
n._top_graph = weakref.ref(graph)
expr._top_graph = weakref.ref(graph)
if isinstance(expr, GetAttr) and isinstance(
expr.outputs[0], ModuleNode
):
obj = getattr(node2obj[expr.inputs[0]], expr.name)
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
if isinstance(obj, TracedModule):
obj._update_modulenode_ref()
if isinstance(expr, Constant) and isinstance(
expr.outputs[0], ModuleNode
):
obj = expr.value
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
if (
isinstance(expr, CallMethod)
and expr.method == "__call__"
and isinstance(expr.inputs[0], ModuleNode)
):
obj = node2obj[expr.inputs[0]]
if expr.arg_def is not None:
next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])

@property
def exprs(self):
return self.graph.exprs
for obj in node2obj.values():
if obj is self:
continue
mnode_map = None
if obj in next_actual_node_map.keys():
mnode_map = next_actual_node_map[obj]
if isinstance(obj, TracedModule):
obj._update_ref(mnode_map)

def flatten(self):
"""
@@ -644,13 +1014,21 @@ class TracedModule(Module):
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
repl_dict = dict(zip(graph._inputs, call.inputs))
for ind, out in enumerate(graph.outputs):
if isinstance(out.expr, Input):
assert out in repl_dict
call_out = call.outputs[ind]
for expr in call.outputs[ind].users:
for index, inp in enumerate(expr.inputs):
if inp is call_out:
expr.inputs[index] = repl_dict[out]

continue
repl_dict[out] = call.outputs[ind]

graph._replace_inputs_outputs(repl_dict)
for expr in graph._exprs:
# replace inputs for submodule's exprs
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)

if isinstance(expr, GetAttr):
# replace GetAttr with Constant
@@ -715,6 +1093,21 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls)


def wrap(func: Union[Callable]):
assert callable(func)
if hasattr(func, "__code__"):
assert not isinstance(func, str)
fn_name = func.__code__.co_name
currentframe = inspect.currentframe()
assert currentframe is not None
f = currentframe.f_back
assert f is not None
if f.f_code.co_name != "<module>":
raise NotImplementedError("wrap must be called at the top level of a module")
Patcher._builtin_functions.append((f.f_globals, fn_name))
return func
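
`wrap` must be called at module top level; it records the caller's globals and the function name so the Patcher treats the function as a builtin during tracing. A decorator-style sketch (the function itself is hypothetical):

    import megengine.functional as F
    from megengine.experimental.traced_module import wrap

    @wrap  # registered in Patcher._builtin_functions
    def my_activation(x):
        return F.relu(x) + 1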


def _register_all_builtin_module():

for sub_mod in [M, M.qat, M.quantized]:
@@ -749,6 +1142,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
for _, i in enumerate(inputs):
assert isinstance(i, Tensor), "not support "
if isinstance(i, RawTensor):
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))


+5 -5  imperative/python/test/unit/traced_module/test_modification.py

@@ -57,16 +57,16 @@ def _init_module():
def test_search():
traced_module, *_ = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu


def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_node = graph.get_call_function(F.relu).as_unique().outputs
neg_node = graph.insert_call_function(F.neg, relu_node)
graph.replace_node({relu_node[0]: neg_node[0]})
relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
graph.replace_node({relu_node[0]: neg_node})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)

@@ -74,7 +74,7 @@ def test_insert():
def test_delete():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
relu_expr = graph.get_function_by_type(F.relu).as_unique()
node = relu_expr.outputs
repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]})

