Browse Source

feat(traced_module): add name to Node

GitOrigin-RevId: 39c2809067
release-1.6
Megvii Engine Team 3 years ago
parent
commit
15712807b9
4 changed files with 331 additions and 80 deletions
  1. +82
    -17
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +37
    -16
      imperative/python/megengine/experimental/traced_module/node.py
  3. +14
    -3
      imperative/python/megengine/experimental/traced_module/pytree.py
  4. +198
    -44
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 82
- 17
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -11,6 +11,7 @@ import builtins
import collections import collections
import copy import copy
import inspect import inspect
import re
from typing import Callable, Dict, List from typing import Callable, Dict, List


from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
@@ -21,7 +22,24 @@ from ...module import Module
from ...tensor import Parameter, Tensor from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef, tree_flatten
from .pytree import ArgsIndex, TreeDef, tree_flatten


def rstrip(s: str, __chars: str):
__chars = re.escape(__chars)
s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, "\g<left>", s)
return s


def lstrip(s: str, __chars: str):
__chars = re.escape(__chars)
s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
return s


def strip(s: str, __chars: str):
s = lstrip(rstrip(s, __chars), __chars)
return s




class Expr: class Expr:
@@ -67,9 +85,29 @@ class Expr:
if not isinstance(outputs, collections.Sequence): if not isinstance(outputs, collections.Sequence):
outputs = (outputs,) outputs = (outputs,)


name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
assert name is not None
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
else:
strip_method = strip(self.method, "_")
name = "%s_out" % strip_method
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"

for i in outputs: for i in outputs:
assert isinstance(i, RawTensor) assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
)


for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node) NodeMixin.wrap_safe(i, node)
@@ -133,11 +171,16 @@ class Input(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope().add_input(expr.outputs[0])
oup_node = expr.outputs[0]
name = (
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node)
return expr.outputs[0] return expr.outputs[0]


def __repr__(self): def __repr__(self):
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)
return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)




# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
@@ -154,22 +197,31 @@ class GetAttr(Expr):
self.name = name self.name = name
node_cls = type if type else Node node_cls = type if type else Node
self.outputs = [ self.outputs = [
node_cls(self),
node_cls(self, name=name),
] ]


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
module = expr.inputs[0]
oup_name = expr.name
while module._name != "self":
oup_name = module._name + "_" + oup_name
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope().insert(expr)
expr.outputs[0]._name = expr.name
return expr.outputs[0] return expr.outputs[0]


def interpret(self, *inputs): def interpret(self, *inputs):
return (getattr(inputs[0], self.name),) return (getattr(inputs[0], self.name),)


def __repr__(self): def __repr__(self):
return '%{}: {} = GetAttr({}, "{}")'.format(
self._id, self.outputs[0], self.inputs[0], self.name
out_type = "Tensor"
if isinstance(self.outputs[0], ModuleNode):
out_type = self.outputs[0].module_type.__name__
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
self._id, self.outputs[0], self.inputs[0], self.name, out_type
) )




@@ -230,11 +282,14 @@ class CallMethod(Expr):
outputs = self.outputs outputs = self.outputs
if self.out_def: if self.out_def:
outputs = self.out_def.unflatten(outputs) outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
method = ".%s" % self.method
if method == ".__call__":
method = ""
return "%{}:\t{}{}{}({})".format(
self._id, self._id,
str(outputs) + " = " if outputs else "", str(outputs) + " = " if outputs else "",
self.args[0], self.args[0],
self.method,
method,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )


@@ -259,7 +314,7 @@ class Apply(Expr):
return apply(self.opdef, *inputs) return apply(self.opdef, *inputs)


def __repr__(self): def __repr__(self):
return "%{}: {} = {}({})".format(
return "%{}:\t{} = {}({})".format(
self._id, self._id,
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.opdef, self.opdef,
@@ -314,10 +369,10 @@ class CallFunction(Expr):
outputs = self.outputs outputs = self.outputs
if self.out_def: if self.out_def:
outputs = self.out_def.unflatten(outputs) outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
return "%{}:\t{}{}({})".format(
self._id, self._id,
str(outputs) + " = " if outputs else "", str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__,
self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )


@@ -328,21 +383,25 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model # TODO: constant cache to reduce the size of dumped model
_constant_cache = {} _constant_cache = {}


def __init__(self, c):
def __init__(self, c, name=None):
super().__init__() super().__init__()
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) assert module_tracer.is_builtin(c)
self.value = c self.value = c
self.name = name
self.inputs = [] self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c) node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [ self.outputs = [
node_cls(self),
node_cls(self, name=name),
] ]


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name)
expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope().insert(expr)
return expr.outputs[0] return expr.outputs[0]


@@ -352,8 +411,14 @@ class Constant(Expr):
return (self.value,) return (self.value,)


def __repr__(self): def __repr__(self):
return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
name = self.name
if name is None:
name = type(self.value)
node_type = "Module"
if isinstance(self.outputs[0], TensorNode):
node_type = "Tensor"
return "%{}:\t{} = Constant({}) -> ({})".format(
self._id, self.outputs[0], name, node_type
) )


def __getstate__(self): def __getstate__(self):


+ 37
- 16
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -28,8 +28,9 @@ class Node:
expr = None expr = None
__total_id = 0 __total_id = 0
_id = None _id = None
_name = None
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None
_format_spec = ""


def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
self.expr = expr self.expr = expr
@@ -43,10 +44,35 @@ class Node:
Node.__total_id = max(Node.__total_id, self._id) + 1 Node.__total_id = max(Node.__total_id, self._id) + 1


def __repr__(self): def __repr__(self):
if self._name is None:
return "%{}".format(self._id)
format_spec = Node._format_spec
return self.__format__(format_spec)

def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None:
format_spec = Node._format_spec
name = self._name
if name is None:
name = ""
if format_spec in ["i", "p", "ip", "pi"]:
if "p" in format_spec:
graph = self.top_graph
prefix_name = ""
if graph is not None:
prefix_name = graph._name
if graph._prefix_name:
prefix_name = "{}_{}".format(
graph._prefix_name, prefix_name.lstrip("_")
)
if name:
name = "_" + name.lstrip("_")
name = "{}{}".format(prefix_name, name)
if "i" in format_spec:
if name:
name = "_" + name.lstrip("_")
name = "%{}{}".format(self._id, name)
return name
else: else:
return "%{}".format(self._name)
return name if name else ("%d" % self._id)


@property @property
def top_graph(self): def top_graph(self):
@@ -54,6 +80,12 @@ class Node:
return self._top_graph() return self._top_graph()
return None return None


@classmethod
def set_format_spec(cls, str):
old_format_spec = cls._format_spec
cls._format_spec = str
return old_format_spec



class ModuleNode(Node): class ModuleNode(Node):
""" """
@@ -72,12 +104,6 @@ class ModuleNode(Node):
super().__init__(expr, name) super().__init__(expr, name)
self.actual_mnode = [] self.actual_mnode = []


def __repr__(self):
if self._name is None:
return "%{}_({})".format(self._id, self.module_type.__name__)
else:
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)

def __getstate__(self): def __getstate__(self):
return { return {
"expr": self.expr, "expr": self.expr,
@@ -104,12 +130,6 @@ class TensorNode(Node):
qparam = None qparam = None
device = None device = None


def __repr__(self):
if self._name is None:
return "%{}_(Tensor)".format(self._id)
else:
return "%{}_{}(Tensor)".format(self._id, self._name)

def __getstate__(self): def __getstate__(self):
return { return {
"expr": self.expr, "expr": self.expr,
@@ -119,6 +139,7 @@ class TensorNode(Node):
"shape": self.shape, "shape": self.shape,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
"_name": self._name,
} }






+ 14
- 3
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -22,6 +22,16 @@ from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor from ...tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode



class ArgsIndex:
def __init__(self, index=0, name="") -> None:
self.index = index
self.name = name

def __repr__(self) -> str:
return self.name


SUPPORTED_TYPE = {} SUPPORTED_TYPE = {}


# if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree # if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree
@@ -39,6 +49,7 @@ SUPPORTED_LEAF_TYPE = {
type(None), type(None),
type(Ellipsis), type(Ellipsis),
QuantMode, QuantMode,
ArgsIndex,
} }


# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
@@ -121,11 +132,11 @@ def _is_leaf(obj):


def _leaf_type(node): def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)): if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode)
return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)): elif isinstance(node, (NodeMixin, Module)):
return (Module, ModuleNode, NodeMixin)
return (Module, ModuleNode, NodeMixin, ArgsIndex)
else: else:
return type(node)
return (type(node), ArgsIndex)




def _is_const_leaf(node): def _is_const_leaf(node):


+ 198
- 44
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -6,12 +6,15 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections import collections
import copy import copy
import fnmatch
import functools import functools
import inspect
import keyword
import re
import weakref import weakref
from inspect import getmembers, isclass, ismethod
from inspect import getcallargs, getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union


from ... import functional as F from ... import functional as F
@@ -41,11 +44,19 @@ from .module_tracer import (
set_active_module_tracer, set_active_module_tracer,
) )
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten
from .pytree import ArgsIndex, tree_flatten


logger = get_logger(__name__) logger = get_logger(__name__)




def _is_builtin_name(name: str) -> bool:
return (
name in builtins.__dict__
or name in keyword.kwlist
or name in {"inf", "nan", "NoneType"}
)


def _is_leaf(node): def _is_leaf(node):
assert isinstance(node, RawTensor), "doesn't support {} in return values".format( assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node) type(node)
@@ -67,6 +78,7 @@ class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True): def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
self.graph = graph self.graph = graph
self.global_scope = InternalGraph() self.global_scope = InternalGraph()
self.global_scope._used_names.update(graph._used_names)
self.expr = expr self.expr = expr
self.after = after self.after = after


@@ -91,6 +103,7 @@ class _InsertExprs:
for expr in self.global_scope._exprs: for expr in self.global_scope._exprs:
self.graph._exprs.insert(index, expr) self.graph._exprs.insert(index, expr)
index += 1 index += 1
self.graph._used_names.update(self.global_scope._used_names)




class InternalGraph: class InternalGraph:
@@ -107,17 +120,37 @@ class InternalGraph:
_inputs = None # type: List[Node] _inputs = None # type: List[Node]
_outputs = None # type: List[Node] _outputs = None # type: List[Node]


def __init__(self):
def __init__(self, name: str = None, prefix_name: str = ""):
self._exprs = [] self._exprs = []
self._inputs = [] self._inputs = []
self._outputs = [] self._outputs = []
self._watch_point = [] self._watch_point = []
self._end_point = [] self._end_point = []
self._used_names = {}
self._rst = collections.defaultdict(list) self._rst = collections.defaultdict(list)
self._name = name
self._prefix_name = prefix_name


def insert(self, expr): def insert(self, expr):
self._exprs.append(expr) self._exprs.append(expr)


def _create_unique_name(self, name: str) -> str:
assert isinstance(name, str)
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = "_{}".format(name)

while name in self._used_names or _is_builtin_name(name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
else:
base, num = match.group(1, 2)
name = "{}_{}".format(base, int(num) + 1)

self._used_names.setdefault(name)
return name

@property @property
def inputs(self): def inputs(self):
return self._inputs return self._inputs
@@ -150,13 +183,16 @@ class InternalGraph:
def get_node_by_id(self, node_id: List[int] = None): def get_node_by_id(self, node_id: List[int] = None):
return self.node_filter.node_id(node_id) return self.node_filter.node_id(node_id)


def get_node_by_name(self, name: str = None, ignorecase: bool = True):
return self.node_filter.name(name, ignorecase)

def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)


def add_output(self, o): def add_output(self, o):
self._outputs.append(o) self._outputs.append(o)


def _replace_inputs_outputs(self, repl_dict):
def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):


for node, repl_node in repl_dict.items(): for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs assert node in self._inputs or node in self._outputs
@@ -175,13 +211,29 @@ class InternalGraph:
for expr in self._exprs: for expr in self._exprs:


for idx, i in enumerate(expr.inputs): for idx, i in enumerate(expr.inputs):
assert i._name is not None
if i in repl_dict: if i in repl_dict:
expr.inputs[idx] = repl_dict[i] expr.inputs[idx] = repl_dict[i]
elif isinstance(i, TensorNode) and prefix_name not in i._name:
if i.top_graph != active_module_tracer().current_scope():
i._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + i._name.lstrip("_"))
)


for idx, o in enumerate(expr.outputs): for idx, o in enumerate(expr.outputs):
assert o._name is not None
if o in repl_dict: if o in repl_dict:
expr.outputs[idx] = repl_dict[o] expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr expr.outputs[idx].expr = expr
elif isinstance(o, TensorNode) and prefix_name not in i._name:
if o.top_graph != active_module_tracer().current_scope():
o._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + o._name.lstrip("_"))
)


def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
@@ -258,7 +310,7 @@ class InternalGraph:
# return formal_node_inputs[1:], actual_nodes # return formal_node_inputs[1:], actual_nodes
return formal_node_inputs[1:] return formal_node_inputs[1:]


def add_input_node(self, shape, dtype="float32"):
def add_input_node(self, shape, dtype="float32", name="args"):
forma_mnode = self.inputs[0] forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode actual_mnodes = forma_mnode.actual_mnode


@@ -271,11 +323,11 @@ class InternalGraph:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append(c_expr) call_nodes.append(c_expr)


def create_node(is_input: bool = True):
def create_node(name=None, is_input: bool = True):
if is_input: if is_input:
node = Input(type=TensorNode).outputs[0]
node = Input(type=TensorNode, name=name).outputs[0]
else: else:
node = TensorNode(expr=None)
node = TensorNode(expr=None, name=None)
node.shape = shape node.shape = shape
node.dtype = dtype node.dtype = dtype
return node return node
@@ -286,7 +338,7 @@ class InternalGraph:
org_argdef = call_nodes[0].arg_def org_argdef = call_nodes[0].arg_def


args, kwargs = org_argdef.unflatten(self._inputs) args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(True)
formal_inp_node = create_node(self._create_unique_name(name), True)
inputs, tree_def = tree_flatten( inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs), ((*args, formal_inp_node), kwargs),
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
@@ -524,11 +576,21 @@ class InternalGraph:
return self.interpret(*inp) return self.interpret(*inp)


def __repr__(self): def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
return self.__format__()

def __format__(self, format_spec: str = "") -> str:
saved_format_spec = Node.set_format_spec(format_spec)
name = ""
if self._name:
name = "%s.Graph" % self._name
res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
name,
", ".join(str(i) for i in self._inputs), ", ".join(str(i) for i in self._inputs),
"\n\t".join("{}".format(str(i)) for i in self._exprs), "\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs), ", ".join(str(i) for i in self._outputs),
) )
Node.set_format_spec(saved_format_spec)
return res




def _get_meth_name(obj, func): def _get_meth_name(obj, func):
@@ -621,6 +683,7 @@ class TracedModuleBuilder(NodeMixin):
self._is_builtin = module_tracer.is_builtin(mod) self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {} self._argdef_graph_map = {}
self._argdef_outdef_map = {} self._argdef_outdef_map = {}

self.nodes = set() self.nodes = set()
# The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__. # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__.
# modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
@@ -631,7 +694,7 @@ class TracedModuleBuilder(NodeMixin):
) )


def build(self): def build(self):
if self._is_builtin:
if self._is_builtin or isinstance(self._mod, TracedModule):
for node in self.nodes: for node in self.nodes:
node.module_type = type(self._mod) node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod) # node._owner = weakref.ref(self._mod)
@@ -671,21 +734,38 @@ class TracedModuleBuilder(NodeMixin):


callnode.arg_def = tree_def callnode.arg_def = tree_def


if self._is_builtin:
if (
self._is_builtin
or tree_def in self._argdef_graph_map
or isinstance(self._mod, TracedModule)
):
unset_module_tracing() unset_module_tracing()
rst = self._mod(*args, **kwargs) rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
set_module_tracing() set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
elif tree_def in self._argdef_graph_map:
self._body = self._argdef_graph_map[tree_def]
else:
self._mod._is_top = False
self._body = self._mod.graph
name = NodeMixin.get(self)._name
if name:
self._body._name = name
else: else:
self_node = None self_node = None
if tree_def in self._argdef_graph_map:
self_node = self._argdef_graph_map[tree_def].inputs[0]
self._body = InternalGraph()
orig_self = NodeMixin.get(self)
top_graph = active_module_tracer().current_scope()
graph_prefix_name = top_graph._name
if top_graph._prefix_name:
graph_prefix_name = "{}_{}".format(
top_graph._prefix_name, graph_prefix_name.lstrip("_")
)
self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
orig_self = NodeMixin.get(self)
if self_node: if self_node:
NodeMixin.wrap_safe(self, self_node) NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node) active_module_tracer().current_scope().add_input(self_node)
@@ -698,16 +778,37 @@ class TracedModuleBuilder(NodeMixin):
) )
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x):
index_args, index_kwargs = tree_def.unflatten(
[
ArgsIndex(0),
*list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
]
)
key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
idx2key = {}
for k, v in key2idx.items():
if isinstance(v, ArgsIndex):
idx2key[v.index] = k
else:
flatten_argidx, _ = tree_flatten(v)
for _i, v in enumerate(flatten_argidx):
if isinstance(v, ArgsIndex):
idx2key[v.index] = k + "_%d" % _i

def wrap(x, name):
if isinstance(x, (RawTensor, NodeMixin)): if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap( NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
x,
lambda: Input.make(
type=NodeMixin.get_wrapped_type(x), name=name
),
) )
return x return x


args = [self] args = [self]
for i in inputs[1:]:
args.append(wrap(i))
for i, v in enumerate(inputs[1:]):
args.append(wrap(v, idx2key[i + 1]))

args, kwargs = tree_def.unflatten(args) args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch( active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
@@ -857,6 +958,9 @@ class NodeFilter(BaseFilter):
def node_id(self, node_id: List[int]): def node_id(self, node_id: List[int]):
return NodeFilterNodeId(self, node_id) return NodeFilterNodeId(self, node_id)


def name(self, name: str, ignorecase: bool = True):
return NodeFilterName(self, name, ignorecase)



class NodeFilterType(NodeFilter): class NodeFilterType(NodeFilter):
def __init__(self, expr_iter, owner_type, node_type): def __init__(self, expr_iter, owner_type, node_type):
@@ -887,6 +991,33 @@ class NodeFilterNodeId(NodeFilter):
yield node yield node




class NodeFilterName(NodeFilter):
_re = None

def __init__(self, node_iter, pattern, ignorecase):
super().__init__(node_iter)
self.pattern = pattern
self._re = self.make_re(pattern, ignorecase)

@classmethod
def make_re(cls, pattern, ignorecase=True):
assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
assert isinstance(ignorecase, bool)
flags = 0
if ignorecase:
flags |= re.IGNORECASE
return re.compile(fnmatch.translate(pattern), flags=flags)

def __iter__(self):
for i in self._iter:
graph = i.top_graph
name = "{}_{}".format(graph._name, i._name.lstrip("_"))
if graph._prefix_name:
name = "{}_{}".format(graph._prefix_name, name.lstrip("_"))
if self.pattern == name or self._re.match(name):
yield i


class ExprFilterCallFunction(ExprFilter): class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None): def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter) super().__init__(expr_iter)
@@ -1052,12 +1183,29 @@ class TracedModule(Module):
:return: :class:`TracedModule` :return: :class:`TracedModule`
""" """
new_module = copy.deepcopy(self) new_module = copy.deepcopy(self)

def _flatten_subgraph(graph, module, call=None):
module2name = {}
assert active_module_tracer() is None
set_active_module_tracer(module_tracer(lambda x: x))
active_module_tracer().push_scope(new_module.graph)
for n, m in new_module.named_modules():
module2name[id(m)] = n

def _flatten_subgraph(
graph: InternalGraph, module: Module, call=None, prefix_name=""
):
if graph is not None and prefix_name and prefix_name[-1] != "_":
prefix_name += "_"
if graph is None: if graph is None:
assert not isinstance(module, TracedModule) assert not isinstance(module, TracedModule)
const = Constant(module)
const.outputs[0] = call.inputs[0]
const = Constant(module, "self.%s" % module2name[id(module)])
m_node = call.inputs[0]
if m_node.top_graph != active_module_tracer().current_scope():
m_node._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name)
)
const.outputs[0] = m_node
const.outputs[0].expr = const const.outputs[0].expr = const
return [const, call] return [const, call]
if call is not None: if call is not None:
@@ -1083,7 +1231,7 @@ class TracedModule(Module):
continue continue
repl_dict[out] = call.outputs[ind] repl_dict[out] = call.outputs[ind]


graph._replace_inputs_outputs(repl_dict)
graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)


for expr in graph._exprs: for expr in graph._exprs:
if isinstance(expr, GetAttr): if isinstance(expr, GetAttr):
@@ -1109,7 +1257,14 @@ class TracedModule(Module):
if hasattr(obj, "argdef_graph_map") if hasattr(obj, "argdef_graph_map")
else None else None
) )
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
exprs.extend(
_flatten_subgraph(
expr_graph,
obj,
expr,
prefix_name + obj_node._name.lstrip("_"),
)
)
else: else:
# module has been replaced. # module has been replaced.
assert isinstance(pre_expr, Constant) assert isinstance(pre_expr, Constant)
@@ -1126,7 +1281,18 @@ class TracedModule(Module):
return exprs return exprs


new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)

new_module.graph.compile()
set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs):
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
return new_module return new_module


def __getstate__(self): def __getstate__(self):
@@ -1149,19 +1315,7 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls) module_tracer.register_as_builtin(mod_cls)




def wrap(func: Union[Callable]):
assert callable(func)
if hasattr(func, "__code__"):
assert not isinstance(func, str)
fn_name = func.__code__.co_name
currentframe = inspect.currentframe()
assert currentframe is not None
f = currentframe.f_back
assert f is not None
if f.f_code.co_name != "<module>":
raise NotImplementedError("wrap must be called at the top level of a module")
Patcher._builtin_functions.append((f.f_globals, fn_name))
return func
wrap = _wrapped_function




def _register_all_builtin_module(): def _register_all_builtin_module():
@@ -1192,11 +1346,11 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer(module_tracer(_wrapped_function)) set_active_module_tracer(module_tracer(_wrapped_function))


with active_module_tracer().patcher: with active_module_tracer().patcher:
global_scope = InternalGraph()
global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)

builder = TracedModuleBuilder(mod, True) builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
name = mod._name if mod._name else mod.__class__.__name__
NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support " # assert isinstance(i, Tensor), "not support "


Loading…
Cancel
Save