Browse Source

feat(traced_module): update graph transform and add _module_name

GitOrigin-RevId: ef63ae0fd0
release-1.6
Megvii Engine Team 3 years ago
parent
commit
a3f9073c2c
7 changed files with 582 additions and 200 deletions
  1. +0
    -1
      imperative/python/megengine/experimental/traced_module/__init__.py
  2. +66
    -40
      imperative/python/megengine/experimental/traced_module/expr.py
  3. +39
    -3
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  4. +64
    -17
      imperative/python/megengine/experimental/traced_module/node.py
  5. +1
    -1
      imperative/python/megengine/experimental/traced_module/pytree.py
  6. +408
    -135
      imperative/python/megengine/experimental/traced_module/traced_module.py
  7. +4
    -3
      imperative/python/test/unit/traced_module/test_modification.py

+ 0
- 1
imperative/python/megengine/experimental/traced_module/__init__.py View File

@@ -14,7 +14,6 @@ from .traced_module import (
register_as_builtin,
trace_module,
wrap,
wrap_tensors,
)

_register_all_builtin_module()


+ 66
- 40
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
return s


def lstrip(s: str, __chars: str):
__chars = re.escape(__chars)
s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
return s


def strip(s: str, __chars: str):
s = lstrip(rstrip(s, __chars), __chars)
return s


class Expr:
"""
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
@@ -89,27 +78,40 @@ class Expr:
outputs = (outputs,)

name = None
orig_name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
assert name is not None
orig_name = self.inputs[0]._orig_name
assert isinstance(name, str), "The name of ({}) must be a str".format(
self.inputs[0]
)
assert isinstance(
orig_name, str
), "The orig_name of ({}) must be a str".format(self.inputs[0])
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
orig_name += "_out"
else:
strip_method = strip(self.method, "_")
strip_method = self.method.strip("_")
name = "%s_out" % strip_method
orig_name = name
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"

for i in outputs:
assert isinstance(i, RawTensor)
assert isinstance(i, RawTensor), "The output must be a Tensor"
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
NodeMixin.get_wrapped_type(i)(
expr=self,
name=o_name,
orig_name=orig_name if orig_name else o_name,
)
)

for i, node in zip(outputs, self.outputs,):
@@ -125,21 +127,26 @@ class Expr:
else:
return inputs, {}

def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
def replace_inputs(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
assert node in self.inputs, "({}) is not in the ({})".format(node, self)
assert (
repl_node.top_graph == node.top_graph
), "({}) and ({}) are not in the same graph".format(node, repl_node)
graph = self.top_graph
repl_expr_idx = graph._exprs.index(repl_node.expr)
self_idx = graph._exprs.index(self)
assert (
repl_expr_idx < self_idx
), "({}) must be generated before ({})".format(repl_node, self)
idx = self.inputs.index(node)
self.inputs[idx] = repl_node
user_idx = node.users.index(self)
assert user_idx >= 0
node.users.pop(user_idx)
repl_node.users.append(self)
node.users.pop(self)

def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)

def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)

@property
def kwargs(self):
@@ -159,7 +166,8 @@ class Expr:

def __getstate__(self):
state = self.__dict__.copy()
state.pop("_top_graph", None)
if "_top_graph" in state:
state.pop("_top_graph")
return state


@@ -167,12 +175,14 @@ class Expr:
class Input(Expr):
name = None

def __init__(self, name=None, type=None):
def __init__(self, name=None, type=None, orig_name=None):
super().__init__()
self.inputs = []
node_cls = type if type else Node
if orig_name is None:
orig_name = name
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=orig_name),
]
self.name = name

@@ -184,7 +194,7 @@ class Input(Expr):
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node)
active_module_tracer().current_scope()._add_input(oup_node)
return expr.outputs[0]

def __repr__(self):
@@ -195,7 +205,7 @@ class Input(Expr):
class GetAttr(Expr):
name = None

def __init__(self, module, name, type=None):
def __init__(self, module, name, type=None, orig_name=None):
super().__init__()
assert isinstance(module, ModuleNode)
self.inputs = [
@@ -205,7 +215,7 @@ class GetAttr(Expr):
self.name = name
node_cls = type if type else Node
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=orig_name),
]

@classmethod
@@ -218,7 +228,7 @@ class GetAttr(Expr):
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0]

def interpret(self, *inputs):
@@ -255,7 +265,7 @@ class CallMethod(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr

@property
@@ -315,7 +325,7 @@ class Apply(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr

def interpret(self, *inputs):
@@ -382,7 +392,7 @@ class CallFunction(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr

def interpret(self, *inputs):
@@ -423,7 +433,7 @@ class Constant(Expr):
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=name),
]
self.outputs[0]._name = name if name else "const_" + str(self._id)

@@ -431,9 +441,23 @@ class Constant(Expr):
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name)
full_name = name
if (
isinstance(expr.value, RawTensor)
and id(expr.value) in active_module_tracer().id2name
):
full_name = active_module_tracer().id2name[id(expr.value)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = ("self." + full_name)[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
name = active_module_tracer().current_scope()._create_unique_name(full_name)
expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr)
expr.outputs[0]._orig_name = full_name
active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0]

def interpret(self, *inputs):
@@ -453,7 +477,9 @@ class Constant(Expr):
)

def __getstate__(self):
state = super().__getstate__()
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value)
return state

+ 39
- 3
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [
"__setitem__",
]

BUILTIN_TENSOR_WRAP_METHOD = [
"T",
"to",
"size",
"shape",
"detach",
"device",
"dtype",
"grad",
"item",
"name",
"ndim",
"numpy",
"qparams",
"set_value",
"reset_zero",
"requires_grad",
"_reset",
"_isscalar",
"_setscalar",
"_tuple_shape",
"_unsetscalar",
]


def get_tensor_wrapable_method():
return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD


def active_module_tracer():
return _active_module_tracer
@@ -101,9 +129,10 @@ class module_tracer:

_active_scopes = None

def __init__(self, wrap_fn):
def __init__(self, wrap_fn, id2name):
self._active_scopes = []
self.patcher = Patcher(wrap_fn)
self.id2name = id2name

@classmethod
def register_as_builtin(cls, mod):
@@ -127,6 +156,10 @@ class module_tracer:
return None


class NotExist:
pass


class PatchedFn:
frame_dict = None
name = None
@@ -138,14 +171,17 @@ class PatchedFn:
self.origin_fn = (
self.frame_dict[name]
if isinstance(frame_dict, collections.abc.Mapping)
else getattr(frame_dict, name)
else getattr(frame_dict, name, NotExist)
)

def set_func(self, func):
if isinstance(self.frame_dict, collections.abc.Mapping):
self.frame_dict[self.name] = func
else:
setattr(self.frame_dict, self.name, func)
if func is not NotExist:
setattr(self.frame_dict, self.name, func)
else:
delattr(self.frame_dict, self.name)


class Patcher:


+ 64
- 17
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -30,14 +30,17 @@ class Node:
_id = None
_top_graph = None # type: weakref.ReferenceType
_name = None
_orig_name = None
_format_spec = ""

def __init__(self, expr: "Expr", name: str = None):
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
Node.__total_id += 1
self._name = name
self._orig_name = orig_name
self.actual_node = [] # type: List[Node]

def __setstate__(self, d):
self.__dict__ = d
@@ -48,7 +51,7 @@ class Node:
return self.__format__(format_spec)

def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None:
if not format_spec:
format_spec = Node._format_spec
name = self._name
if name is None:
@@ -100,9 +103,8 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
_owner = None # type: weakref.ReferenceType

def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.actual_mnode = []
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
super().__init__(expr, name, orig_name)

def __getstate__(self):
return {
@@ -110,6 +112,7 @@ class ModuleNode(Node):
"users": self.users,
"_id": self._id,
"_name": self._name,
"_orig_name": self._orig_name,
"module_type": self.module_type,
}

@@ -125,23 +128,67 @@ class TensorNode(Node):
``TensorNode`` represents the Tensor objects.
"""

shape = None # type: Tuple[int]
dtype = None # type: numpy.dtype
qparams = None
device = None
_shape = None # type: Tuple[int]
_dtype = None # type: numpy.dtype
_qparams = None
_device = None
_value = None # type: Tensor

def __getstate__(self):
return {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"qparams": self.qparams,
"shape": self.shape,
"dtype": self.dtype,
"device": self.device,
"_qparams": self._qparams,
"_shape": self._shape,
"_dtype": self._dtype,
"_device": self._device,
"_name": self._name,
"_orig_name": self._orig_name,
}

@property
def shape(self):
return self._shape

@shape.setter
def shape(self, shape):
self._shape = shape

@property
def dtype(self):
return self._dtype

@dtype.setter
def dtype(self, dtype):
self._dtype = dtype

@property
def device(self):
return self._device

@device.setter
def device(self, device):
self._device = device

@property
def qparams(self):
return self._qparams

@qparams.setter
def qparams(self, qparams):
self._qparams = qparams

@property
def value(self):
return self._value

@value.setter
def value(self, value):
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None)
self._value = value


class NodeMixin(abc.ABC):
__node = None
@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC):
assert isinstance(node, TensorNode)
assert isinstance(value, RawTensor)
if isinstance(value, RawTensor):
node.dtype = value.dtype
node.shape = (
node._dtype = value.dtype
node._shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
node.device = value.device
node._device = value.device
if hasattr(value, "_qparams") and value._qparams is not None:
node.qparams = value.qparams
node._qparams = value.qparams

@classmethod
def wrap(cls, value, node):


+ 1
- 1
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -133,7 +133,7 @@ def _is_leaf(obj):
def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)):
elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin, ArgsIndex)
else:
return (type(node), ArgsIndex)


+ 408
- 135
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -9,14 +9,19 @@
import builtins
import collections
import copy
import ctypes
import fnmatch
import functools
import inspect
import keyword
import re
import weakref
from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from megengine import tensor

from ... import functional as F
from ... import get_logger
from ... import module as M
@@ -44,8 +49,10 @@ from ...tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
PatchedFn,
Patcher,
active_module_tracer,
get_tensor_wrapable_method,
module_tracer,
set_active_module_tracer,
)
@@ -70,46 +77,267 @@ def _is_leaf(node):
return isinstance(node, RawTensor)


def wrap_tensors(tensors: Tensor, nodes: TensorNode):
inp_tensors = copy.deepcopy(tensors)
inp_tensors, inp_def_v = tree_flatten(inp_tensors)
inp_nodes, inp_def_n = tree_flatten(nodes)
for v, n in zip(inp_tensors, inp_nodes):
if isinstance(n, TensorNode) and isinstance(v, Tensor):
NodeMixin.wrap_safe(v, n)
return inp_def_v.unflatten(inp_tensors)
_enable_node_to_tensor = False


def _convert_node_flag():
return _enable_node_to_tensor


def _set_convert_node_flag(flag: bool = False):
global _enable_node_to_tensor
pre_flag = _enable_node_to_tensor
_enable_node_to_tensor = flag
return pre_flag


def _node_to_tensor(*args, **kwargs):
tensors = []
nodes, tree_def = tree_flatten((args, kwargs))
for n in nodes:
if isinstance(n, TensorNode):
if n.top_graph is not None:
active_module_tracer().current_scope()._add_input(n)
value = n.value
if value is None:
flag = _set_convert_node_flag(False)
unset_module_tracing()
value = F.zeros(shape=n._shape, dtype=n._dtype)
set_module_tracing()
_set_convert_node_flag(flag)
orig_n = NodeMixin.get(value, None)
if orig_n is None or "setitem" not in orig_n._name:
NodeMixin.wrap_safe(value, n)
tensors.append(value)
else:
tensors.append(n)
tensors = tree_def.unflatten(tensors)
return tensors


def _tensor_to_node(tensors):
if tensors is None:
return None
nodes = []
tensors, out_def = tree_flatten(tensors)
for t in tensors:
if isinstance(t, Tensor):
n = NodeMixin.get(t, None)
if isinstance(n, TensorNode):
n.value = t
nodes.append(n)
else:
nodes.append(t)
else:
nodes.append(t)
nodes = out_def.unflatten(nodes)
return nodes


def _wrap_method_to_tensor_node():
def _any_method(name):
def _any(*args, **kwargs):
args, kwargs = _node_to_tensor(*args, **kwargs)
attr = getattr(args[0], name)
outs = attr
if callable(attr):
outs = attr(*(args[1:]), **kwargs)
if name == "__setitem__":
_node_to_tensor(outs)
return None
outs = _tensor_to_node(outs)
return outs

return _any

tensor_method_patch = []
for method in get_tensor_wrapable_method():
patch = PatchedFn(TensorNode, method)
if type(getattr(Tensor, method)) == property:
patch.set_func(property(_any_method(method)))
else:
patch.set_func(_any_method(method))
tensor_method_patch.append(patch)
return tensor_method_patch


def _convert_node_and_tensor(orig_func):
@functools.wraps(orig_func)
def _convert(*args, **kwargs):
if _convert_node_flag() and is_tracing_module():
args, kwargs = _node_to_tensor(*args, **kwargs)
rst = orig_func(*args, **kwargs, method_func=_convert)
rst = _tensor_to_node(rst)
return rst
else:
rst = orig_func(*args, **kwargs)
return rst

return _convert


def _wrap_mnode_getattr(orig_getattr):
@functools.wraps(orig_getattr)
def wraped_fn(self, name):
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
attr = getattr(obj, name)
node = attr
full_name = None
if id(attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(attr)]

if not isinstance(attr, TracedModuleBuilder):
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(obj, name, attr)
active_module_tracer().id2name[id(attr)] = full_name

if isinstance(attr, (NodeMixin, RawTensor)):
if full_name:
scope_name = active_module_tracer().current_scope()._module_name
if scope_name:
full_name = full_name[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
self,
name,
type=NodeMixin.get_wrapped_type(attr),
orig_name=full_name,
),
)
if isinstance(attr, (NodeMixin, RawTensor)):
node = NodeMixin.get(attr)
if isinstance(node, ModuleNode):
node._owner = weakref.ref(attr)
return node

return wraped_fn


def _wrap_mnode_call(orig_call):
@functools.wraps(orig_call)
def wraped_fn(self, *args, **kwargs):
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
rst = obj(*args, **kwargs)
return rst

return wraped_fn


def _init_id2name(mod: Module, prefix: str = ""):
id2name = {
id(m): "%s.%s" % (prefix, key)
for key, m in chain(
mod.named_modules(), mod.named_parameters(), mod.named_buffers()
)
}
return id2name


class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
def __init__(self, graph, expr: Optional[Expr] = None):
self.graph = graph
self.global_scope = InternalGraph()
self.global_scope = InternalGraph(
graph._name, graph._prefix_name, graph._module_name
)
self.global_scope._used_names.update(graph._used_names)
self.expr = expr
self.after = after
self._tensor_method_patch = None

def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
_set_convert_node_flag(True)
assert active_module_tracer() is None
set_active_module_tracer(module_tracer(_wrapped_function))
module = self.graph.inputs[0].owner
_wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
set_active_module_tracer(
module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
)
active_module_tracer().patcher.__enter__()
for cls, name, func in [
[ModuleNode, "__getattr__", _wrap_mnode_getattr],
[ModuleNode, "__call__", _wrap_mnode_call],
[TracedModuleBuilder, "__call__", _convert_node_and_tensor],
]:
active_module_tracer().patcher.patch_function(cls, name, func)
self._tensor_method_patch = _wrap_method_to_tensor_node()
active_module_tracer().push_scope(self.global_scope)

def __exit__(self, ty, va, tr):
if va is not None:
return False
set_symbolic_shape(self.use_sym_shape)
unset_module_tracing()
active_module_tracer().patcher.__exit__(ty, va, tr)
_set_convert_node_flag(False)

while self._tensor_method_patch:
pf = self._tensor_method_patch.pop()
pf.set_func(pf.origin_fn)

module = self.graph.inputs[0].owner

for mod, parent in module.modules(with_parent=True):
name = mod._name
if isinstance(mod, TracedModuleBuilder):
mod = mod.build()
if hasattr(mod, "graph"):
for node in mod.graph.nodes():
node.value = None
setattr(parent, name, mod)
set_active_module_tracer(None)
index = len(self.graph._exprs) if self.after else 0

for node in self.global_scope.nodes():
node.value = None

extra_inp_nodes = set(self.global_scope.inputs)
max_inp_expr_idx = -1
for node in extra_inp_nodes:
assert (
node.top_graph == self.graph
), "The input node ({}) is not in the graph ({})".format(node, self.graph)
if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
max_inp_expr_idx = max(
max_inp_expr_idx, self.graph._exprs.index(node.expr)
)
max_inp_expr_idx += 1

insert_index = -1
if self.expr is not None:
index = self.graph._exprs.index(self.expr)
if self.after:
index += 1
insert_index = self.graph._exprs.index(self.expr)
insert_index += 1

if insert_index < max_inp_expr_idx:
insert_index = max_inp_expr_idx

anchor_index = insert_index - 1
if anchor_index >= 0:
logger.info(
"The new expr will be inserted after ( {} )".format(
self.graph._exprs[anchor_index]
)
)

for expr in self.global_scope._exprs:
self.graph._exprs.insert(index, expr)
index += 1
self.graph._exprs.insert(insert_index, expr)
insert_index += 1

self.graph._used_names.update(self.global_scope._used_names)
graph = self.graph
while graph.top_graph is not None:
graph = graph.top_graph
graph.inputs[0].owner._update_ref()
return True


class InternalGraph:
@@ -125,8 +353,9 @@ class InternalGraph:
_exprs = None # type: List[Expr]
_inputs = None # type: List[Node]
_outputs = None # type: List[Node]
_top_graph = None

def __init__(self, name: str = None, prefix_name: str = ""):
def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
self._exprs = []
self._inputs = []
self._outputs = []
@@ -136,12 +365,13 @@ class InternalGraph:
self._rst = collections.defaultdict(list)
self._name = name
self._prefix_name = prefix_name
self._module_name = module_name

def insert(self, expr):
def _insert(self, expr):
self._exprs.append(expr)

def _create_unique_name(self, name: str) -> str:
assert isinstance(name, str)
assert isinstance(name, str), "The name must be a str"
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = "_{}".format(name)
@@ -166,40 +396,45 @@ class InternalGraph:
return self._outputs

@property
def expr_filter(self):
return ExprFilter(_expr_iter(self))
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None

@property
def node_filter(self):
return NodeFilter(_node_iter(self))
def exprs(self, recursive=True):
return ExprFilter(_expr_iter(self, recursive))

def nodes(self, recursive=True):
return NodeFilter(_node_iter(self, recursive))

def get_function_by_type(self, func: Callable = None):
return self.expr_filter.call_function(func)
def get_function_by_type(self, func: Callable = None, recursive=True):
return self.exprs(recursive).call_function(func)

def get_method_by_type(self, method: str = None):
return self.expr_filter.call_method(method)
def get_method_by_type(self, method: str = None, recursive=True):
return self.exprs(recursive).call_method(method)

def get_expr_by_id(self, expr_id: List[int] = None):
return self.expr_filter.expr_id(expr_id)
def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
return self.exprs(recursive).expr_id(expr_id)

def get_module_by_type(self, module_cls: Module):
def get_module_by_type(self, module_cls: Module, recursive=True):
assert issubclass(module_cls, Module)
return self.node_filter.type(module_cls, ModuleNode)
return self.nodes(recursive).type(module_cls, ModuleNode)

def get_node_by_id(self, node_id: List[int] = None):
return self.node_filter.node_id(node_id)
def get_node_by_id(self, node_id: List[int] = None, recursive=True):
return self.nodes(recursive).node_id(node_id)

def get_node_by_name(self, name: str = None, ignorecase: bool = True):
return self.node_filter.name(name, ignorecase)
def get_node_by_name(
self, name: str = None, ignorecase: bool = True, recursive=True
):
return self.nodes(recursive).name(name, ignorecase)

def add_input(self, i):
def _add_input(self, i):
self._inputs.append(i)

def add_output(self, o):
def _add_output(self, o):
self._outputs.append(o)

def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):

def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs
for i in node.users:
@@ -212,12 +447,15 @@ class InternalGraph:

for idx, o in enumerate(self._outputs):
if o in repl_dict:
repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
self._outputs[idx] = repl_dict[o]

for expr in self._exprs:

for idx, i in enumerate(expr.inputs):
assert i._name is not None
assert isinstance(
i._name, str
), "The node ({}) name must be a str".format(i)
if i in repl_dict:
expr.inputs[idx] = repl_dict[i]
elif isinstance(i, TensorNode) and prefix_name not in i._name:
@@ -227,9 +465,12 @@ class InternalGraph:
.current_scope()
._create_unique_name(prefix_name + i._name.lstrip("_"))
)
i._orig_name = "{}{}".format(module_name, i._orig_name)

for idx, o in enumerate(expr.outputs):
assert o._name is not None
assert isinstance(
o._name, str
), "The node ({}) name must be a str".format(i)
if o in repl_dict:
expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr
@@ -240,6 +481,7 @@ class InternalGraph:
.current_scope()
._create_unique_name(prefix_name + o._name.lstrip("_"))
)
o._orig_name = "{}{}".format(module_name, o._orig_name)

def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
@@ -263,7 +505,7 @@ class InternalGraph:

def reset_inputs(self, *args, **kwargs):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode
actual_mnodes = forma_mnode.actual_node
call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
@@ -318,7 +560,7 @@ class InternalGraph:

def add_input_node(self, shape, dtype="float32", name="args"):
forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode
actual_mnodes = forma_mnode.actual_node

moudle = forma_mnode.owner
assert moudle._is_top, "add_input_node only support the top-level graph"
@@ -378,7 +620,7 @@ class InternalGraph:
moudle = forma_mnode.owner
assert moudle._is_top, "reset_outputs only support the top-level graph"

actual_mnodes = forma_mnode.actual_mnode
actual_mnodes = forma_mnode.actual_node
call_nodes = []
for n in actual_mnodes:
for c_expr in n.users:
@@ -406,7 +648,6 @@ class InternalGraph:

self._outputs[:] = outputs
moudle.argdef_outdef_map[tree_def] = out_def

return actual_nodes

def add_output_node(self, node: TensorNode):
@@ -415,7 +656,7 @@ class InternalGraph:
moudle = forma_mnode.owner
assert moudle._is_top, "add_output_node only support the top-level graph"

actual_mnodes = forma_mnode.actual_mnode
actual_mnodes = forma_mnode.actual_node
call_nodes = []

for n in actual_mnodes:
@@ -455,74 +696,35 @@ class InternalGraph:

return actual_out_nodes

def insert_function(self, func: Callable, *args, **kwargs):
assert isinstance(func, Callable)

inp_nodes, inp_def = tree_flatten((args, kwargs))

insert_idx = -1
for i in inp_nodes:
if isinstance(i, TensorNode) and i.expr in self._exprs:
insert_idx = max(insert_idx, self._exprs.index(i.expr))

fake_inp_val = list(
F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
for i in inp_nodes
)

for v, n in zip(fake_inp_val, inp_nodes):
if isinstance(n, TensorNode):
NodeMixin.wrap_safe(v, n)

fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)

insert_point = self.insert_exprs_before()
if insert_idx != -1:
insert_point = self.insert_exprs_after(self._exprs[insert_idx])

with insert_point:
rst = func(*fake_args, **fake_kwargs)

if rst is None:
return None

outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
node_outputs = []
for out in outputs:
assert isinstance(out, RawTensor)
node_outputs.append(NodeMixin.get(out, None))

node_outputs = out_def.unflatten(node_outputs)
return node_outputs

def insert_exprs_after(self, expr: Optional[Expr] = None):
def insert_exprs(self, expr: Optional[Expr] = None):
if expr is not None:
assert expr.top_graph == self, "Expr to insert after is not in graph."
return _InsertExprs(self, expr, after=True)

def insert_exprs_before(self, expr: Optional[Expr] = None):
if expr is not None:
assert expr.top_graph == self, "Expr to insert before is not in graph."
return _InsertExprs(self, expr, after=False)
return _InsertExprs(self, expr)

def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
# check graph inputs and outputs
assert node not in self.inputs, "Cannot replace inputs"
# assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs):
if n is node:
self.outputs[i] = repl_node
# update users of node and repl_node
# update inputs of expr in node.users
graph = repl_node.top_graph
assert graph is not None
index = graph._exprs.index(repl_node.expr)
dep_exprs = self.get_dep_exprs(repl_node)
i = 0
while i < len(node.users):
n = node.users[i]
if n in graph._exprs and index >= graph._exprs.index(n):
i += 1
continue
if n in dep_exprs:
logger.info("Find a loop: ignore this replacement once")
logger.info("node: %s" % node.__repr__())
logger.info("repl_node: %s" % repl_node.__repr__())
logger.info("expr: %s" % n.__repr__())
i += 1
continue
repl_node.users.append(n)
@@ -598,6 +800,12 @@ class InternalGraph:
Node.set_format_spec(saved_format_spec)
return res

def __getstate__(self):
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
return state


def _get_meth_name(obj, func):
tp = obj if isinstance(obj, type) else type(obj)
@@ -611,6 +819,9 @@ def _get_meth_name(obj, func):
def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*args, **kwargs):
method_func = wrapped_fn
if "method_func" in kwargs:
method_func = kwargs.pop("method_func")
if is_tracing_module():
unset_module_tracing()
inputs, tree_def = tree_flatten((args, kwargs))
@@ -618,9 +829,11 @@ def _wrapped_function(orig_func):
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
if meth_name and issubclass(arg_type, RawTensor):
meth_name, arg_type = None, None
if args:
meth_name = _get_meth_name(args[0], method_func)
arg_type = args[0] if isinstance(args[0], type) else type(args[0])
if meth_name and arg_type and issubclass(arg_type, RawTensor):
self = inputs[0]
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -799,6 +1012,9 @@ class TracedModuleBuilder(NodeMixin):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph
if "method_func" in kwargs:
kwargs.pop("method_func")

def mark_constant(x):
node = NodeMixin.get(x, None)
if node is None: # capture as constant
@@ -829,9 +1045,6 @@ class TracedModuleBuilder(NodeMixin):
else:
self._mod._is_top = False
self._body = self._mod.graph
name = NodeMixin.get(self)._name
if name:
self._body._name = name
else:
self_node = None
orig_self = NodeMixin.get(self)
@@ -841,19 +1054,24 @@ class TracedModuleBuilder(NodeMixin):
graph_prefix_name = "{}_{}".format(
top_graph._prefix_name, graph_prefix_name.lstrip("_")
)
self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
module_name = orig_self._orig_name
if top_graph._module_name:
module_name = "{}.{}".format(top_graph._module_name, module_name)
self._body = InternalGraph(
orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
)
active_module_tracer().push_scope(self._body)
# rebind self to new input node

if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
active_module_tracer().current_scope()._add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
@@ -893,12 +1111,13 @@ class TracedModuleBuilder(NodeMixin):
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
rst = type(self._mod).forward(*args, **kwargs)
if _convert_node_flag():
rst = _node_to_tensor(rst)[0][0]
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))
NodeMixin.get(self, None).actual_mnode.append(orig_self)
active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
NodeMixin.wrap_safe(self, orig_self)
for arg, node in zip(inputs[1:], origin_inp_node):
if node:
@@ -923,14 +1142,33 @@ class TracedModuleBuilder(NodeMixin):
attr = getattr(type(self._mod), name).__get__(self, type(self))
else:
attr = getattr(self._mod, name)
full_name = None

if id(attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(attr)]

if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)

if isinstance(attr, (Module, RawTensor)):
setattr(self, name, attr)
active_module_tracer().id2name[id(attr)] = full_name

if full_name:
scope_name = active_module_tracer().current_scope()._module_name
if scope_name:
full_name = full_name[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(attr),
orig_name=full_name,
),
)
return attr
@@ -951,7 +1189,16 @@ class TracedModuleBuilder(NodeMixin):
assert mod_attr is wrapped._mod
else:
assert mod_attr is wrapped

full_name = None
if id(mod_attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(mod_attr)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = full_name[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
# assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
@@ -960,6 +1207,7 @@ class TracedModuleBuilder(NodeMixin):
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
orig_name=full_name,
),
)

@@ -967,24 +1215,25 @@ class TracedModuleBuilder(NodeMixin):


class _expr_iter:
def __init__(self, graph: InternalGraph):
def __init__(self, graph: InternalGraph, recursive: bool = True):
self.graph = graph
self.recursive = recursive

def __iter__(self):
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
if expr.graph is not None:
yield from expr.graph.expr_filter
if self.recursive and expr.graph is not None:
yield from expr.graph.exprs(self.recursive)
else:
yield expr


class _node_iter:
def __init__(self, graph: InternalGraph) -> None:
def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
nodes = []
node_ids = set()
for expr in graph.expr_filter:
for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs:
if n._id in node_ids:
continue
@@ -1210,14 +1459,17 @@ class TracedModule(Module):
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]

def _update_ref(self, actual_node_map: Union[Dict] = None):
def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
for inp_def, graph in self.argdef_graph_map.items():
if top_graph is not None:
graph._top_graph = weakref.ref(top_graph)
for n in graph._inputs + graph.outputs:
n._top_graph = weakref.ref(graph)
graph._inputs[0]._owner = weakref.ref(self)
graph._inputs[0].actual_mnode = []
if actual_node_map is not None and inp_def in actual_node_map.keys():
graph._inputs[0].actual_mnode = actual_node_map[inp_def]
for i, n in enumerate(graph._inputs):
n.actual_node = []
if actual_node_map is not None and inp_def in actual_node_map.keys():
n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
node2obj = {}
next_actual_node_map = collections.defaultdict(
lambda: collections.defaultdict(list)
@@ -1246,7 +1498,7 @@ class TracedModule(Module):
):
obj = node2obj[expr.inputs[0]]
if expr.arg_def is not None:
next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
next_actual_node_map[obj][expr.arg_def].append(expr.inputs)

for obj in node2obj.values():
if obj is self:
@@ -1255,7 +1507,7 @@ class TracedModule(Module):
if obj in next_actual_node_map.keys():
mnode_map = next_actual_node_map[obj]
if isinstance(obj, TracedModule):
obj._update_ref(mnode_map)
obj._update_ref(mnode_map, graph)

def flatten(self):
"""
@@ -1264,21 +1516,25 @@ class TracedModule(Module):
:return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)
module2name = {}
assert active_module_tracer() is None
set_active_module_tracer(module_tracer(lambda x: x))
id2name = _init_id2name(new_module, "self")
set_active_module_tracer(module_tracer(lambda x: x, {}))
active_module_tracer().push_scope(new_module.graph)
for n, m in new_module.named_modules():
module2name[id(m)] = n

def _flatten_subgraph(
graph: InternalGraph, module: Module, call=None, prefix_name=""
graph: InternalGraph,
module: Module,
call=None,
prefix_name="",
module_name="",
):
if graph is not None and prefix_name and prefix_name[-1] != "_":
if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
prefix_name += "_"
if isinstance(module_name, str) and module_name:
module_name += "."
if graph is None or module.is_qat:
assert not isinstance(module, TracedModule) or module.is_qat
const = Constant(module, "self.%s" % module2name[id(module)])
const = Constant(module, id2name[id(module)])
m_node = call.inputs[0]
if m_node.top_graph != active_module_tracer().current_scope():
m_node._name = (
@@ -1286,6 +1542,7 @@ class TracedModule(Module):
.current_scope()
._create_unique_name(prefix_name)
)
m_node._orig_name = id2name[id(module)][5:]
const.outputs[0] = m_node
const.outputs[0].expr = const
return [const, call]
@@ -1312,7 +1569,7 @@ class TracedModule(Module):
continue
repl_dict[out] = call.outputs[ind]

graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)
graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)

for expr in graph._exprs:
if isinstance(expr, GetAttr):
@@ -1344,6 +1601,7 @@ class TracedModule(Module):
obj,
expr,
prefix_name + obj_node._name.lstrip("_"),
module_name + obj_node._orig_name,
)
)
else:
@@ -1358,7 +1616,6 @@ class TracedModule(Module):
if call is not None:
for i in call.inputs:
i.users.remove(call)

return exprs

new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -1396,7 +1653,22 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls)


wrap = _wrapped_function
def wrap(func: Callable):
"""
Call this function to register func as a builtin function.
"""
assert callable(func), "func must be a callable"
assert hasattr(func, "__code__")
fn_name = func.__code__.co_name
currentframe = inspect.currentframe()
assert currentframe is not None
f = currentframe.f_back
assert f is not None
assert (
f.f_code.co_name == "<module>"
), "wrap must be called at the top level of a module"
Patcher._builtin_functions.append((f.f_globals, fn_name))
return func


def _register_all_builtin_module():
@@ -1438,14 +1710,15 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
try:
use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
set_active_module_tracer(module_tracer(_wrapped_function))

set_active_module_tracer(
module_tracer(_wrapped_function, _init_id2name(mod, "self"))
)
with active_module_tracer().patcher:
global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod, True)
name = mod._name if mod._name else mod.__class__.__name__
NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support "


+ 4
- 3
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -64,9 +64,10 @@ def test_search():
def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
graph.replace_node({relu_node[0]: neg_node})
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
with graph.insert_exprs():
neg_out = F.neg(relu_out)
graph.replace_node({relu_out: neg_out})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)



Loading…
Cancel
Save