GitOrigin-RevId: aaa9e51c74
release-1.7
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
+from . import compat
 from .traced_module import (
     TracedModule,
     _register_all_builtin_module,
@@ -0,0 +1,136 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from .. import tensor
+from ..core.ops.builtin import BatchNorm
+from .expr import CallMethod, Constant
+from .node import TensorNode
+from .serialization import (
+    register_functional_loader,
+    register_module_loader,
+    register_opdef_loader,
+    register_tensor_method_loader,
+)
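+
+# The loaders below patch Exprs that were dumped by an older MegEngine
+# version so that the graph can be loaded by the current version.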
""" | |||||
# Expr loaders examples | |||||
from ..core.ops.builtin import Elemwise | |||||
@register_opdef_loader(Elemwise) | |||||
def add_opdef_loader(expr): | |||||
if expr.opdef_state["mode"] == "ADD": | |||||
expr.opdef_state["mode"] == "MUL" | |||||
node = expr.inputs[1] | |||||
astype_expr = CallMethod(node, "astype") | |||||
oup = TensorNode( | |||||
astype_expr, | |||||
shape=node.shape, | |||||
dtype=expr.inputs[0].dtype, | |||||
qparams=node.qparams, | |||||
) | |||||
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||||
astype_expr.return_val = (oup,) | |||||
expr.inputs[1] = oup | |||||
+
+
+@register_functional_loader(("megengine.functional.nn", "conv2d"))
+def conv2df_loader(expr):
+    # expr.func = ("megengine.functional.nn", "conv2d")
+    kwargs = expr.kwargs
+    orig_weight = expr.named_args["weight"]
+    astype_expr = CallMethod(orig_weight, "astype")
+    oup = TensorNode(
+        astype_expr,
+        shape=orig_weight.shape,
+        dtype=orig_weight.dtype,
+        qparams=orig_weight.qparams,
+    )
+    astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
+    astype_expr.return_val = (oup,)
+    expr.set_arg("weight", oup)
+
+
+@register_module_loader(("megengine.module.conv", "Conv2d"))
+def conv2dm_loader(expr):
+    module = expr.inputs[0].owner
+    args = list(expr.args)
+    orig_inp = args[1]
+    astype_expr = CallMethod(orig_inp, "astype")
+    oup = TensorNode(
+        astype_expr,
+        shape=orig_inp.shape,
+        dtype=orig_inp.dtype,
+        qparams=orig_inp.qparams,
+    )
+    astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
+    astype_expr.return_val = (oup,)
+    args[1] = oup
+    expr.set_args_kwargs(*args)
+
+
+@register_tensor_method_loader("__add__")
+def add_loader(expr):
+    args = list(expr.args)
+    if not isinstance(args[1], TensorNode):
+        args[1] = tensor(args[1])
+        node = Constant(args[1], "const").outputs[0]
+        astype_expr = CallMethod(node, "astype")
+        oup = TensorNode(
+            astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
+        )
+        astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
+        astype_expr.return_val = (oup,)
+        args[1] = oup
+        expr.set_args_kwargs(*args)
+"""
+
+
+@register_module_loader(
+    ("megengine.module.batchnorm", "BatchNorm1d"),
+    ("megengine.module.batchnorm", "BatchNorm2d"),
+    ("megengine.module.batchnorm", "SyncBatchNorm"),
+)
+def bn2d_module_loader(expr):
+    # mge 1.6
+    if not hasattr(expr, "version"):
+        module = expr.inputs[0].owner
+        if not hasattr(module, "param_dim"):
+            module.param_dim = "dim_1c11"
+
+
+@register_module_loader(
+    ("megengine.module.conv_bn", "ConvBn2d"),
+    ("megengine.module.conv_bn", "ConvBnRelu2d"),
+    ("megengine.module.qat.conv_bn", "ConvBn2d"),
+    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
+)
+def convbn2d_module_loader(expr):
+    # mge 1.6
+    if not hasattr(expr, "version"):
+        module = expr.inputs[0].owner
+        if not hasattr(module.bn, "param_dim"):
+            module.bn.param_dim = "dim_1c11"
+
+
+@register_opdef_loader(BatchNorm)
+def bn_opdef_loader(expr):
+    # mge 1.6
+    if not hasattr(expr, "version"):
+        output = expr.outputs[-1]
+        oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
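+        # newer BatchNorm opdefs carry one more output; insert a placeholder
+        # node so that graphs dumped by mge 1.6 match the current output layout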
+        expr.outputs.insert(4, oup)
@@ -11,19 +11,28 @@ import collections
 import copy
 import inspect
 import re
-from typing import Callable, Dict, List, Optional, Union
+import weakref
+from importlib import import_module
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 from ..core._imperative_rt import OpDef
 from ..core._imperative_rt.core2 import Tensor as RawTensor
-from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ..core._imperative_rt.core2 import (
+    apply,
+    is_tracing_module,
+    set_module_tracing,
+    unset_module_tracing,
+)
 from ..core.ops.builtin import FakeQuant
 from ..core.ops.special import Const
 from ..module import Module
 from ..tensor import Parameter, Tensor
+from ..version import __version__
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
 from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
-from .serialization import get_opdef_state, load_opdef_from_state
+from .serialization import _ModuleState
+from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args


 def rstrip(s: str, __chars: str):
@@ -112,6 +121,7 @@ class Expr:
                 node.users.append(self)
             else:
                 assert node is None
+                assert not isinstance(val, (Module, RawTensor))
                 assert _is_leaf(val) and _is_const_leaf(val)
                 idx = len(self.inputs) + len(self.const_val)
                 self.const_val.append((idx, val))
@@ -132,14 +142,14 @@ class Expr:
         current_graph._namespace.auto_naming_for_outputs(self)

     def unflatten_args(self, inputs):
-        if self.arg_def is not None:
-            inputs = list(inputs)
-            for idx, val in self.const_val:
-                inputs.insert(idx, val)
-            args, kwargs = self.arg_def.unflatten(inputs)
-            return args, kwargs
-        else:
-            return inputs, {}
+        assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
+            type(self).__name__
+        )
+        inputs = list(inputs)
+        for idx, val in self.const_val:
+            inputs.insert(idx, val)
+        args, kwargs = self.arg_def.unflatten(inputs)
+        return args, kwargs

     def replace_inputs(self, repl_dict: Dict[Node, Node]):
         r"""Replace the input Nodes of this Expr.
@@ -166,6 +176,39 @@ class Expr:
             repl_node.users.append(self)

     @property
+    def _support_set_args_kwargs(self):
+        return False
+
+    def set_args_kwargs(self, *args, **kwargs):
+        r"""Set the args and kwargs of this Expr."""
+        assert (
+            self._support_set_args_kwargs
+        ), "Doesn't support setting args/kwargs for {} expr".format(type(self).__name__)
+        args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
+        inputs, arg_def = tree_flatten((args, kwargs))
+        orig_inputs = self.inputs
+        self.inputs = []
+        self.const_val = []
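+        # rebuild from the flattened (args, kwargs): Nodes become inputs of
+        # this Expr, constants are recorded with their flattened position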
+        for val in inputs:
+            if isinstance(val, (TensorNode, ModuleNode)):
+                self.inputs.append(val)
+            else:
+                assert _is_leaf(val) and _is_const_leaf(val)
+                idx = len(self.inputs) + len(self.const_val)
+                self.const_val.append((idx, val))
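+        # keep Node.users consistent with the rewritten input list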
+        for n in orig_inputs:
+            if n not in self.inputs:
+                n.users.remove(self)
+        for n in self.inputs:
+            if n not in orig_inputs:
+                n.users.append(self)
+        self.arg_def = arg_def
+
+    @property
     def kwargs(self):
         r"""Get the keyword arguments of the operation corresponding to this Expr."""
         _, kwargs = self.unflatten_args(self.inputs)
@@ -177,6 +220,61 @@ class Expr:
         args, _ = self.unflatten_args(self.inputs)
         return args

+    def _get_func(self):
+        # return the function that is called when this expr is interpreted
+        raise NotImplementedError
+
+    @property
+    def named_args(self):
+        func = self._get_func()
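+        # map args/kwargs onto func's signature, yielding {param_name: value}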
+        return inspect.getcallargs(func, *self.args, **self.kwargs)
+
+    def set_arg(self, name, val):
+        func = self._get_func()
+        if name in self.kwargs:
+            new_kwargs = self.kwargs
+            new_kwargs[name] = val
+            self.set_args_kwargs(*self.args, **new_kwargs)
+        else:
+            arg_spec = inspect.getfullargspec(func)
+            if name in arg_spec.args:
+                ind = arg_spec.args.index(name)
+                new_args = list(self.args)
+                new_args[ind] = val
+                self.set_args_kwargs(*new_args)
+            elif name == arg_spec.varargs:
+                assert arg_spec.varargs is not None
+                assert len(self.args) >= len(arg_spec.args)
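+                # "name" addresses the *args parameter: replace everything
+                # after the named positional arguments with the new value(s)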
+                val = (val,) if not isinstance(val, Sequence) else val
+                self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
+            else:
+                assert (
+                    arg_spec.varkw is not None
+                ), "func {} doesn't have an argument named {}".format(func, name)
+                new_kwargs = self.kwargs
+                new_kwargs[name] = val
+                self.set_args_kwargs(*self.args, **new_kwargs)
+
+    @property
+    def return_val(self):
+        return self.out_def.unflatten(self.outputs)
+
+    @return_val.setter
+    def return_val(self, new_outputs):
+        outputs, out_def = tree_flatten(
+            new_outputs, is_leaf=lambda x: isinstance(x, Node)
+        )
+        assert all(
+            isinstance(o, Node) for o in outputs
+        ), "Return values of an expr must be ModuleNode/TensorNode or containers of them"
+        assert all(
+            o.expr in (None, self) for o in outputs
+        ), "Some nodes are produced by another expr and can not be outputs of expr {}".format(
+            self
+        )
+        self.outputs = outputs
+        self.out_def = out_def

     @property
     def top_graph(self):
         r"""Get the parent graph of this Expr."""
@@ -184,12 +282,6 @@ class Expr:
             return self._top_graph()
         return None

-    def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
-        return state
-
     @classmethod
     def _get_next_id(cls):
         return cls.__total_id
@@ -199,6 +291,23 @@ class Expr:
         assert isinstance(id, int)
         cls.__total_id = id

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
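+        # weakrefs (e.g. the _top_graph backref) cannot be deep-copied; they
+        # are re-established when the copy is attached to a graph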
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType):
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result


 # expr: None (i.e. fake expression which is used to mark input)
 class Input(Expr):
@@ -229,6 +338,17 @@ class Input(Expr):
     def __repr__(self):
         return "%{}:\t{} = Input()".format(self._id, self.outputs[0])

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "name": self.name,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = getattr(inputs[0], self.name)
 class GetAttr(Expr):
@@ -276,11 +396,23 @@ class GetAttr(Expr):
     def __repr__(self):
         out_type = "Tensor"
         if isinstance(self.outputs[0], ModuleNode):
-            out_type = self.outputs[0].module_type.__name__
+            m_type = self.outputs[0].module_type
+            out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
         return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
             self._id, self.outputs[0], self.inputs[0], self.name, out_type
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "name": self.name,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = inputs[0].__call__(*inputs[1:])
 class CallMethod(Expr):
@@ -307,6 +439,7 @@ class CallMethod(Expr):
                 node,
             ]
             self.const_val = []
+        self.arg_def = tree_flatten(((node,), {}))[1]
         self.method = method

     @classmethod
@@ -342,6 +475,27 @@ class CallMethod(Expr):
         outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs

+    def _get_func(self):
+        if isinstance(self.args[0], type):
+            obj_type = self.args[0]
+        elif isinstance(self.args[0], ModuleNode):
+            obj_type = self.args[0].module_type
+        else:
+            assert isinstance(self.args[0], TensorNode)
+            obj_type = Tensor
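+        # module calls dispatch to forward(); everything else resolves the
+        # named method on the owning type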
+        meth = getattr(
+            obj_type, "forward" if issubclass(obj_type, Module) else self.method
+        )
+        return meth
+
+    @property
+    def _support_set_args_kwargs(self):
+        # only exprs that call a tensor method or a builtin module support
+        # modifying their args/kwargs
+        return (
+            isinstance(self.args[0], (TensorNode, type))
+            or self.args[0].module_type is not Module
+        )
+
     def __repr__(self):
         args = ", ".join(str(i) for i in self.args[1:])
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
@@ -359,6 +513,21 @@ class CallMethod(Expr):
             ", ".join([args, kwargs]),
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "const_val": self.const_val,
+            "method": self.method,
+            "arg_def": self.arg_def,
+            "out_def": self.out_def,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = apply(self.opdef, *inputs)
 class Apply(Expr):
@@ -394,14 +563,32 @@ class Apply(Expr):
         )

     def __getstate__(self):
-        state = super().__getstate__()
-        state["opdef"] = get_opdef_state(state["opdef"])
+        opdef_state = self.opdef.__getstate__()
+        opdef_state["opdef_type"] = type(self.opdef)
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "opdef_state": opdef_state,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
         return state

     def __setstate__(self, state):
-        state["opdef"] = load_opdef_from_state(state["opdef"])
-        for k, v in state.items():
-            setattr(self, k, v)
+        # compat with mge 1.6
+        if "opdef" in state and "opdef_state" not in state:
+            opdef_state = state.pop("opdef")
+            opdef_state["opdef_type"] = opdef_state.pop("type")
+            state["opdef_state"] = opdef_state
+        self.__dict__.update(state)
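+        # rebuild the OpDef instance from its dumped state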
+        assert isinstance(state["opdef_state"], dict)
+        opdef_state = state["opdef_state"].copy()
+        opdef_type = opdef_state.pop("opdef_type")
+        opdef_obj = opdef_type()
+        opdef_obj.__setstate__(opdef_state)
+        setattr(self, "opdef", opdef_obj)

     @classmethod
     def apply_module_trace_hook(cls, opdef, *inputs):
@@ -458,12 +645,24 @@ class CallFunction(Expr):
     def interpret(self, *inputs):
         args, kwargs = self.unflatten_args(inputs)
-        outputs = self.func(*args, **kwargs)
+        func = (
+            self.func
+            if not is_tracing_module()
+            else active_module_tracer().patcher.wrap_fn(self.func)
+        )
+        outputs = func(*args, **kwargs)
         if outputs is None:
             return outputs
         outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs

+    def _get_func(self):
+        return self.func
+
+    @property
+    def _support_set_args_kwargs(self):
+        return True
+
     def __repr__(self):
         args = ", ".join(str(i) for i in self.args)
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
@@ -477,6 +676,33 @@ class CallFunction(Expr):
             ", ".join([args, kwargs]),
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "func": (self.func.__module__, self.func.__qualname__),
+            "const_val": self.const_val,
+            "inputs": self.inputs,
+            "arg_def": self.arg_def,
+            "out_def": self.out_def,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
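+        # func was dumped as (module_name, qualified_name); resolve it back to
+        # a callable, leaving the tuple in place if resolution fails so that a
+        # registered functional loader can handle it later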
+        try:
+            if isinstance(self.func, tuple):
+                mname, fname = self.func
+                f = import_module(mname)
+                for i in fname.split("."):
+                    f = getattr(f, i)
+                self.func = f
+        except Exception:
+            pass
+

 # expr outputs = self.value
 class Constant(Expr):
@@ -496,6 +722,13 @@ class Constant(Expr):
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
             assert module_tracer.is_builtin(c) or c.is_qat
+        if isinstance(c, RawTensor):
+            if is_tracing_module():
+                unset_module_tracing()
+                c = Tensor(c)
+                set_module_tracing()
+            else:
+                c = Tensor(c)
         self.value = c
         self.name = name
         self.inputs = []
@@ -530,9 +763,25 @@ class Constant(Expr):
         )

     def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "value": self.value,
+            "name": self.name,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+        }
+        _check_obj_attr(state)
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
+        if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
+            _check_builtin_module_attr(self.value)
+            state["value"] = _ModuleState.get_module_state(self.value)
         return state
+
+    def __setstate__(self, state):
+        for k, v in state.items():
+            if isinstance(v, _ModuleState):
+                state[k] = v.to_module()
+        self.__dict__.update(state)
@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
     "astype",
     "reshape",
     "_broadcast",
-    "transpose",
     "flatten",
     "sum",
     "prod",
@@ -6,7 +6,9 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import abc
+import copy
 import weakref
+from importlib import import_module
 from typing import Any, Dict, List, Tuple, Type

 import numpy
@@ -14,7 +16,9 @@ import numpy
 from .. import get_logger
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..module import Module
+from ..quantization.utils import QParams
 from ..tensor import Tensor
+from .utils import _check_obj_attr

 logger = get_logger(__name__)
@@ -145,6 +149,23 @@ class Node:
         assert isinstance(id, int)
         cls.__total_id = id

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
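+        # skip weakref attributes and the transient "actual_node" reference,
+        # which must not be shared between the copy and the original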
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result
+

 class ModuleNode(Node):
     r"""``ModuleNode`` represents the Module objects."""
@@ -157,19 +178,28 @@ class ModuleNode(Node):
         super().__init__(expr, name, qualname)

     def __getstate__(self):
-        return {
+        state = {
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
             "_name": self._name,
             "_qualname": self._qualname,
-            "module_type": self.module_type,
+            "module_type": (self.module_type.__module__, self.module_type.__qualname__),
         }
+        _check_obj_attr(state)
+        return state

     def __setstate__(self, state):
         if "_orig_name" in state:
             state["_qualname"] = state.pop("_orig_name")
         self.__dict__.update(state)
+        try:
+            if isinstance(self.module_type, tuple):
+                mname, classname = self.module_type
+                mtype = getattr(import_module(mname), classname)
+                self.module_type = mtype
+        except Exception:
+            pass

     @property
     def owner(self):
@@ -185,12 +215,26 @@ class TensorNode(Node):
     _shape = None  # type: Tuple[int]
     _dtype = None  # type: numpy.dtype
-    _qparams = None
+    _qparams = None  # type: QParams
     _device = None
     _value = None  # type: Tensor

+    def __init__(
+        self,
+        expr: "Expr",
+        name: str = None,
+        qualname: str = None,
+        shape: Tuple[int] = None,
+        dtype: numpy.dtype = None,
+        qparams: QParams = None,
+    ):
+        super().__init__(expr, name, qualname)
+        self._shape = shape
+        self._dtype = dtype
+        self._qparams = qparams
+
     def __getstate__(self):
-        return {
+        state = {
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
@@ -201,6 +245,8 @@ class TensorNode(Node):
             "_name": self._name,
             "_qualname": self._qualname,
         }
+        _check_obj_attr(state)
+        return state

     def __setstate__(self, state):
         if "_orig_name" in state:
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
             assert isinstance(node, TensorNode)
             assert isinstance(value, RawTensor)
             if isinstance(value, RawTensor):
-                node._dtype = value.dtype
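+                # reading .dtype may raise a RuntimeError (e.g. while the
+                # tensor's dtype is not yet valid during tracing); degrade
+                # to None in that case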
+                try:
+                    node._dtype = value.dtype
+                except RuntimeError:
+                    node._dtype = None
                 node._shape = (
                     value._tuple_shape if isinstance(value, Tensor) else value.shape
                 )
@@ -7,15 +7,18 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from functools import partial
 from typing import Callable, NamedTuple

 import numpy as np

+from ..core._imperative_rt import OpDef
 from ..core._imperative_rt.common import CompNode
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._wrap import Device
 from ..core.tensor.dtype import QuantDtypeMeta
+from ..distributed import Group
 from ..module import Module
 from ..quantization.utils import LSQParams, QParams, QuantMode
 from ..tensor import Parameter, Tensor
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
     type(Ellipsis),
     QuantMode,
     ArgsIndex,
+    Group,
 }

+USER_REGISTERED_LEAF_TYPE = []
+USER_REGISTERED_CONTAINER_TYPE = []
+
 # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object can be treated as a leaf node of the pytree
-SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]
+SUPPORTED_LEAF_CLS = [
+    Module,
+    Node,
+    NodeMixin,
+    np.dtype,
+    np.ndarray,
+    np.number,
+    np.bool_,
+    OpDef,
+]

 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


 def register_supported_type(type, flatten=None, unflatten=None):
+    tp_info = (type.__module__, type.__qualname__)
     if flatten and unflatten:
-        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+        USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
     else:
-        SUPPORTED_LEAF_CLS.append(type)
+        USER_REGISTERED_LEAF_TYPE.append(tp_info)
+    _register_supported_type(type, flatten, unflatten)
+
+
+def _register_supported_type(type, flatten=None, unflatten=None):
+    if flatten and unflatten:
+        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+    else:
+        SUPPORTED_LEAF_CLS.append(type)


-def _dict_flatten(inp):
-    aux_data = []
-    results = []
-    for key, value in sorted(inp.items()):
-        results.append(value)
-        aux_data.append(key)
-    return results, tuple(aux_data)
-
-
-def _dict_unflatten(inps, aux_data):
-    return dict(zip(aux_data, inps))
-
-
-def _ordereddict_flatten(inp):
+def _dict_flatten(ordered, inp):
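+    # ordered=True keeps insertion order (for OrderedDict); otherwise keys
+    # are sorted so that flattening is deterministic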
     aux_data = []
     results = []
-    for key, value in inp.items():
+    dict_items = inp.items() if ordered else sorted(inp.items())
+    for key, value in dict_items:
         results.append(value)
         aux_data.append(key)
     return results, tuple(aux_data)


-def _ordereddict_unflatten(inps, aux_data):
-    return OrderedDict(zip(aux_data, inps))
+def _dict_unflatten(dict_type, inps, aux_data):
+    return dict_type(zip(aux_data, inps))


 def qparams_flatten(inp):
@@ -99,33 +111,41 @@ def qparams_flatten(inp):
     return results, tuple(aux_data)


-def qparams_unflatten(inp, aux_data):
-    obj = QParams.__new__(QParams)
+def qparams_unflatten(qparam_type, inp, aux_data):
+    obj = qparam_type.__new__(qparam_type)
     for k, v in zip(aux_data, inp):
         setattr(obj, k, v)
     return obj


-register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
-register_supported_type(dict, _dict_flatten, _dict_unflatten)
-register_supported_type(
-    collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
+_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
+_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
+_register_supported_type(
+    dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
+)
+_register_supported_type(
+    defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
 )
-register_supported_type(
+_register_supported_type(
+    OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
+)
+_register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
     lambda x, aux_data: slice(x[0], x[1], x[2]),
 )
-register_supported_type(QParams, qparams_flatten, qparams_unflatten)
+_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
+_register_supported_type(
+    LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
+)


 def _is_leaf(obj):
-    if isinstance(obj, type):
-        return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
+    obj_type = obj if isinstance(obj, type) else type(obj)
     return (
-        isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE
+        issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS))
+        or obj_type in SUPPORTED_LEAF_TYPE
     )
@@ -5,30 +5,158 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Dict
+from importlib import import_module
+from typing import Dict, Tuple

 from ..core._imperative_rt import OpDef
 from ..core.ops import builtin
+from ..tensor import Tensor
 from ..version import __version__
+from .utils import _convert_kwargs_to_args

-OPDEF_PARAM_LOADER = {}
+OPDEF_LOADER = {}
+FUNCTIONAL_LOADER = {}
+TENSORMETHOD_LOADER = {}
+MODULE_LOADER = {}


-def get_opdef_state(obj: OpDef) -> Dict:
-    state = obj.__getstate__()
-    state["type"] = type(obj)
-    state["version"] = __version__
-    return state
-
-
-def load_opdef_from_state(state: Dict) -> OpDef:
-    assert "type" in state and issubclass(state["type"], OpDef)
-    assert "version" in state
-    opdef_type = state.pop("type")
-    if opdef_type in OPDEF_PARAM_LOADER:
-        loader = OPDEF_PARAM_LOADER[opdef_type]
-        state = loader(state)
-    state.pop("version")
-    opdef_obj = opdef_type()
-    opdef_obj.__setstate__(state)
-    return opdef_obj
+class _ModuleState:
+    obj = None
+
+    def __init__(self, module: Tuple, state: Dict, version: str):
+        self.module = module
+        self.state = state
+        self.version = version
+
+    @classmethod
+    def get_module_state(cls, module):
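+        # cache the wrapper on the module itself so that repeated dumps reuse
+        # (and refresh) the same _ModuleState instance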
+        typem = (type(module).__module__, type(module).__qualname__)
+        state = module.__dict__.copy()
+        state.pop("_m_dump_modulestate", None)
+        if hasattr(module, "_m_dump_modulestate"):
+            assert isinstance(module._m_dump_modulestate, cls)
+            module._m_dump_modulestate.__init__(typem, state, __version__)
+        else:
+            module.__dict__["_m_dump_modulestate"] = _ModuleState(
+                typem, state, __version__
+            )
+        return module._m_dump_modulestate
+
+    def __getstate__(self):
+        return {"module": self.module, "state": self.state, "version": self.version}
+
+    def to_module(self):
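+        # lazily rebuild the module from (module_path, qualname) and its
+        # captured __dict__, caching the result on this state object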
+        if self.obj is None:
+            typem = getattr(import_module(self.module[0]), self.module[1])
+            m_obj = typem.__new__(typem)
+            m_obj.__dict__.update(self.state)
+            self.obj = m_obj
+        return self.obj
+
+
+def register_opdef_loader(*opdefs):
+    def callback(loader):
+        for opdef in opdefs:
+            assert opdef not in OPDEF_LOADER
+            OPDEF_LOADER[opdef] = loader
+        return loader
+
+    return callback
+
+
+def register_functional_loader(*funcs):
+    def callback(loader):
+        for func in funcs:
+            assert func not in FUNCTIONAL_LOADER
+            FUNCTIONAL_LOADER[func] = loader
+        return loader
+
+    return callback
+
+
+def register_module_loader(*module_types):
+    def callback(loader):
+        for module_type in module_types:
+            assert module_type not in MODULE_LOADER
+            MODULE_LOADER[module_type] = loader
+        return loader
+
+    return callback
+
+
+def register_tensor_method_loader(*methods):
+    def callback(loader):
+        for method in methods:
+            assert method not in TENSORMETHOD_LOADER
+            TENSORMETHOD_LOADER[method] = loader
+        return loader
+
+    return callback
+
+
+def _replace_args_kwargs(expr, new_args, new_kwargs):
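+    # only rewrite the expr when normalization actually changed the
+    # args/kwargs layout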
+    if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
+        expr.kwargs.keys()
+    ):
+        expr.set_args_kwargs(*new_args, **new_kwargs)
+
+
+def load_functional(expr):
+    func = (
+        (expr.func.__module__, expr.func.__qualname__)
+        if callable(expr.func)
+        else expr.func
+    )
+    assert isinstance(func, tuple)
+    if func in FUNCTIONAL_LOADER:
+        loader = FUNCTIONAL_LOADER[func]
+        loader(expr)
+    mname, fname = func
+    f = import_module(mname)
+    for i in fname.split("."):
+        f = getattr(f, i)
+    expr.func = f
+    assert callable(expr.func)
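+    # graphs dumped by another version may lay out arguments differently;
+    # normalize kwargs to positional args against the current signature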
+    if not hasattr(expr, "version") or expr.version != __version__:
+        args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_call_module_expr(expr):
+    m_type = expr.inputs[0].module_type
+    if isinstance(m_type, type):
+        m_type = (m_type.__module__, m_type.__qualname__)
+    if m_type in MODULE_LOADER:
+        MODULE_LOADER[m_type](expr)
+    if isinstance(expr.inputs[0].module_type, tuple):
+        mname, classname = expr.inputs[0].module_type
+        expr.inputs[0].module_type = getattr(import_module(mname), classname)
+    if not hasattr(expr, "version") or expr.version != __version__:
+        fwd_func = getattr(expr.inputs[0].module_type, "forward")
+        args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_call_tensor_method_expr(expr):
+    if expr.method in TENSORMETHOD_LOADER:
+        loader = TENSORMETHOD_LOADER[expr.method]
+        loader(expr)
+    if not hasattr(expr, "version") or expr.version != __version__:
+        tmethod = (
+            getattr(expr.args[0], expr.method)
+            if isinstance(expr.args[0], type)
+            else getattr(Tensor, expr.method)
+        )
+        args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_apply_expr(expr):
+    opdef_type = type(expr.opdef)
+    if opdef_type in OPDEF_LOADER:
+        OPDEF_LOADER[opdef_type](expr)
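+    # a loader may have patched expr.opdef_state; rebuild the opdef from it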
+    opdef_state = expr.opdef_state
+    opdef_obj = opdef_state.pop("opdef_type")()
+    opdef_obj.__setstate__(opdef_state)
+    expr.opdef = opdef_obj
@@ -14,6 +14,7 @@ import inspect
 import keyword
 import re
 import weakref
+from importlib import import_module
 from inspect import getcallargs, getmembers, isclass, ismethod
 from itertools import chain
 from types import FunctionType
@@ -53,6 +54,7 @@ from ..quantization.observer import (
     SyncMinMaxObserver,
 )
 from ..tensor import Tensor
+from ..version import __version__
 from .expr import (
     Apply,
     CallFunction,
@@ -80,8 +82,27 @@ from .module_tracer import (
     set_active_module_tracer,
 )
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import ArgsIndex, tree_flatten
-from .utils import replace_container_with_module_container
+from .pytree import (
+    USER_REGISTERED_CONTAINER_TYPE,
+    USER_REGISTERED_LEAF_TYPE,
+    ArgsIndex,
+    TreeDef,
+    _register_supported_type,
+    tree_flatten,
+)
+from .serialization import (
+    _ModuleState,
+    load_apply_expr,
+    load_call_module_expr,
+    load_call_tensor_method_expr,
+    load_functional,
+)
+from .utils import (
+    _check_builtin_module_attr,
+    _check_obj_attr,
+    _convert_kwargs_to_args,
+    replace_container_with_module_container,
+)

 logger = get_logger(__name__)
@@ -341,7 +362,7 @@ class NameSpace:
     def create_unique_name(self, name: str, node: Any = None) -> str:
         assert isinstance(name, str), "The name must be a string"
-        if name in self._used_names and self._used_names[name] is node:
+        if name in self._used_names and (self._used_names[name] is node):
             return name

         name = re.sub("[^0-9a-zA-Z_]+", "_", name)
@@ -1067,6 +1088,7 @@ class InternalGraph:
                 if node2value[n][1] == 0:
                     node2value.pop(n)
         if values is not None:
+            assert len(values) == len(expr.outputs)
             for n, v in zip(expr.outputs, values):
                 if ref_count(n) > 0:
                     node2value[n] = [v, ref_count(n)]
@@ -1105,13 +1127,27 @@ class InternalGraph:
         return res

     def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
+        state = {
+            "_exprs": self._exprs,
+            "_inputs": self._inputs,
+            "_outputs": self._outputs,
+            "_watch_point": [],
+            "_end_point": [],
+            "_namespace": self._namespace,
+            "_rst": collections.defaultdict(list),
+            "_name": self._name,
+            "_qualname": self._qualname,
+        }
+        if self._total_ids:
+            state["_total_ids"] = self._total_ids
+        _check_obj_attr(state)
         return state

     def __setstate__(self, state):
         old_version = False
         if "_module_name" in state:
             old_version = True
             state["_qualname"] = state.pop("_module_name")
@@ -1144,6 +1180,25 @@ class InternalGraph:
             self._namespace = NameSpace(self._name, self._qualname)
             self._re_associate_name()

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType):
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result
+

 def _get_meth_name(obj, func):
     tp = obj if isinstance(obj, type) else type(obj)
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
 def _wrapped_function(orig_func):
     @functools.wraps(orig_func)
     def wrapped_fn(*args, **kwargs):
-        method_func = wrapped_fn
-        if "method_func" in kwargs:
-            method_func = kwargs.pop("method_func")
+        method_func = kwargs.pop("method_func", wrapped_fn)
         if is_tracing_module():
             unset_module_tracing()
             inputs, tree_def = tree_flatten((args, kwargs))
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
                 if not NodeMixin.get(i, None):
                     if isinstance(i, (RawTensor, NodeMixin)):
                         NodeMixin.wrap_safe(i, Constant.make(i))
-            meth_name, arg_type = None, None
-            if args:
-                meth_name = _get_meth_name(args[0], method_func)
-                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+            args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs)
+            meth_name = _get_meth_name(args[0], method_func)
+            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
             if meth_name and arg_type and issubclass(arg_type, RawTensor):
+                inputs, tree_def = tree_flatten((args, kwargs))
                 self = inputs[0]
                 if meth_name == "__new__":
                     if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
                     call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                 call_node.add_inputs(inputs[1:])
             else:
+                inputs, tree_def = tree_flatten((args, kwargs))
                 call_node = CallFunction.make(orig_func)
                 call_node.add_inputs(inputs)
@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
         "_record_wrapped_nodes",
         "_argdef_graph_map",
         "_argdef_outdef_map",
+        "_check_qat_module",
         "nodes",
         "__class__",
         "__dict__",
+        "_is_top",
     ]

     def __init__(self, mod, is_top_module=False):
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
             qat_module.weight_fake_quant.set_qparams(qparams)

     def build(self):
-        if self._is_builtin or isinstance(self._mod, TracedModule):
-            if module_tracer.is_builtin(self._mod) or isinstance(
-                self._mod, TracedModule
-            ):
-                mod_type = type(self._mod)
-            else:
-                assert isinstance(self._mod, (Observer, _FakeQuantize))
-                mod_type = (
-                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
-                )
+        if self._is_builtin:
+            assert module_tracer.is_builtin(self._mod)
+            mod_type = type(self._mod)
             for node in self.nodes:
                 node.module_type = mod_type
             return self._mod
         else:
-            is_qat = isinstance(self._mod, QATModule)
+            is_qat = isinstance(self._mod, QATModule) or (
+                isinstance(self._mod, TracedModule) and self._mod.is_qat
+            )
             traced_module = TracedModule(
                 self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
             )
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
                 traced_module.with_act = self._mod.with_act
                 traced_module.with_weight = self._mod.with_weight
                 if not hasattr(traced_module, "act_fake_quant"):
-                    traced_module.act_fakequant = None
+                    traced_module.act_fake_quant = None
                 if not hasattr(traced_module, "act_observer"):
                     traced_module.act_observer = None
                 if not hasattr(traced_module, "weight_fake_quant"):
-                    traced_module.weight_fakequant = None
+                    traced_module.weight_fake_quant = None
                 if not hasattr(traced_module, "weight_observer"):
                     traced_module.weight_observer = None
             set_module_tracing()
+            if self._is_top:
+                traced_module._update_ref()
             return traced_module

     def _record_wrapped_nodes(self, node):
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
         # prepare args and kwargs for inner graph
         if "method_func" in kwargs:
             kwargs.pop("method_func")
+        args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)

         def mark_constant(x):
             node = NodeMixin.get(x, None)
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
         callnode.arg_def = tree_def

-        if (
-            self._is_builtin
-            or tree_def in self._argdef_graph_map
-            or isinstance(self._mod, TracedModule)
-        ):
+        if self._is_builtin or tree_def in self._argdef_graph_map:
             unset_module_tracing()
             rst = self._mod(*args, **kwargs)
             outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
                 self._body = None
             elif tree_def in self._argdef_graph_map:
                 self._body = self._argdef_graph_map[tree_def]
-            else:
-                self._mod._is_top = False
-                self._body = self._mod.argdef_graph_map[tree_def]
-                module_qualname = NodeMixin.get(self).qualname
-                if module_qualname != self._body.qualname:
-                    src_name, dst_name = self._body.qualname, module_qualname
-
-                    def replace_qualname(g):
-                        attr_name = get_suffix_name(src_name, g.qualname)
-                        if attr_name is not None:
-                            g._qualname = (
-                                ("%s.%s" % (dst_name, attr_name))
-                                if attr_name
-                                else dst_name
-                            )
-                            assert get_suffix_name(dst_name, g.qualname) is not None
-
-                    for mod in self._mod.modules():
-                        if not hasattr(mod, "argdef_graph_map"):
-                            continue
-                        for g in mod.argdef_graph_map.values():
-                            replace_qualname(g)
-                            g._namespace.qualname = g.qualname
-                            for n in g.nodes(False):
-                                replace_qualname(n)
         else:
-            self_node = None
             orig_self = NodeMixin.get(self)
             parent_graph = active_module_tracer().current_scope()
             module_qualname = orig_self._qualname
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
             active_module_tracer().push_scope(self._body)
             # rebind self to new input node
-            if self_node:
-                NodeMixin.wrap_safe(self, self_node)
-                active_module_tracer().current_scope()._add_input(self_node)
-            else:
-                NodeMixin.wrap_safe(
-                    self,
-                    self_node
-                    if self_node
-                    else Input.make(
-                        name="self",
-                        qualname=module_qualname,
-                        type=NodeMixin.get_wrapped_type(self),
-                    ),
-                )
+            NodeMixin.wrap_safe(
+                self,
+                Input.make(
+                    name="self",
+                    qualname=module_qualname,
+                    type=NodeMixin.get_wrapped_type(self),
+                ),
+            )
             origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
             # prepare args and kwargs for inner graph
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
                 return x

             args = [self]
-            for i, v in enumerate(inputs[1:]):
-                args.append(wrap(v, idx2key[i + 1]))
+            orig_traced_inputs = (
+                None
+                if not isinstance(self._mod, TracedModule)
+                else self._mod.argdef_graph_map[tree_def].inputs
+            )
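+            # when re-tracing a TracedModule, reuse the input names recorded
+            # in its previously traced graph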
+            ind = 1
+            for v in inputs[1:]:
+                if isinstance(v, (RawTensor, NodeMixin)):
+                    args_name = (
+                        orig_traced_inputs[ind]._name
+                        if orig_traced_inputs
+                        else idx2key[ind]
+                    )
+                    ind += 1
+                    args.append(wrap(v, args_name))
+                else:
+                    args.append(v)
             args, kwargs = tree_def.unflatten(args)
             active_module_tracer().patcher.auto_patch(
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
             attr = getattr(type(self._mod), name).__get__(self, type(self))
         else:
             attr = getattr(self._mod, name)
-
         if (
             isinstance(attr, FunctionType)
             and id(attr) in active_module_tracer().patcher.patched_fn_ids
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
             wrapped = self.__getattr__(name)
             if isinstance(wrapped, TracedModuleBuilder):
-                if not isinstance(mod_attr, (List, Dict)):
+                if not isinstance(mod_attr, (List, Dict, QATModule)):
                     assert mod_attr is wrapped._mod
             else:
                 assert mod_attr is wrapped
@@ -1977,8 +2011,6 @@ class TracedModule(Module): | |||||
def graph(self) -> InternalGraph: | def graph(self) -> InternalGraph: | ||||
"""Return the ``InternalGraph`` of this ``TracedModule``. | """Return the ``InternalGraph`` of this ``TracedModule``. | ||||
""" | """ | ||||
if self._is_top: | |||||
self._update_ref() | |||||
assert len(self.argdef_graph_map) == 1 | assert len(self.argdef_graph_map) == 1 | ||||
return list(self.argdef_graph_map.values())[0] | return list(self.argdef_graph_map.values())[0] | ||||
@@ -2112,7 +2144,7 @@ class TracedModule(Module): | |||||
if hasattr(obj, "argdef_graph_map") | if hasattr(obj, "argdef_graph_map") | ||||
else None | else None | ||||
) | ) | ||||
if expr_graph is not None: | |||||
if expr_graph is not None and not obj.is_qat: | |||||
exprs = _flatten_subgraph(graph, expr_graph, expr, obj) | exprs = _flatten_subgraph(graph, expr_graph, expr, obj) | ||||
if parent_graph is not None: | if parent_graph is not None: | ||||
@@ -2137,26 +2169,119 @@ class TracedModule(Module): | |||||
) | ) | ||||
new_module.graph._re_associate_name() | new_module.graph._re_associate_name() | ||||
new_module.graph.compile() | new_module.graph.compile() | ||||
new_module._update_ref() | |||||
new_module.graph._reset_ids() | new_module.graph._reset_ids() | ||||
return new_module | return new_module | ||||
def __getstate__(self): | def __getstate__(self): | ||||
d = self.__dict__ | |||||
d = self.__dict__.copy() | |||||
for k in Module.__dict__: | for k in Module.__dict__: | ||||
d.pop(k, None) | d.pop(k, None) | ||||
_check_obj_attr(d) | |||||
for k in d: | |||||
if module_tracer.is_builtin(d[k]): | |||||
assert _check_builtin_module_attr( | |||||
d[k] | |||||
), "Module {} can not be serialized. ".format(type(d[k])) | |||||
d[k] = _ModuleState.get_module_state(d[k]) | |||||
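# record the framework version and the user-registered types/functions | |||||
# so compatibility can be checked when the dump is loaded | |||||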
dump_info = { | |||||
"version": __version__, | |||||
"register_type": USER_REGISTERED_LEAF_TYPE, | |||||
"register_container_type": USER_REGISTERED_CONTAINER_TYPE, | |||||
"register_mdule": USER_REGISTERED_MODULE, | |||||
"register_function": USER_REGISTERED_FUNCTION, | |||||
} | |||||
d["dump_info"] = dump_info | |||||
return d | return d | ||||
def __setstate__(self, state): | |||||
for k, v in state.items(): | |||||
if isinstance(v, _ModuleState): | |||||
state[k] = v.to_module() | |||||
self.__dict__.update(state) | |||||
self._update_ref() | |||||
for _, graph in self.argdef_graph_map.items(): | |||||
for expr in graph._exprs: | |||||
if isinstance(expr, CallFunction): | |||||
load_functional(expr) | |||||
if isinstance(expr, CallMethod): | |||||
if expr.method == "__call__": | |||||
load_call_module_expr(expr) | |||||
else: | |||||
load_call_tensor_method_expr(expr) | |||||
if isinstance(expr, Apply): | |||||
load_apply_expr(expr) | |||||
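# loaders may create new exprs (e.g. astype) whose producers are not yet | |||||
# listed in graph._exprs; insert them before their first user | |||||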
for _, graph in self.argdef_graph_map.items(): | |||||
ind = 0 | |||||
while ind < len(graph._exprs): | |||||
cur_expr = graph._exprs[ind] | |||||
has_new_expr = False | |||||
for i in cur_expr.inputs: | |||||
if i.expr not in graph._exprs and not isinstance(i.expr, Input): | |||||
graph._exprs.insert(ind, i.expr) | |||||
has_new_expr = True | |||||
if not has_new_expr: | |||||
ind += 1 | |||||
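# after editing, reconcile each node's users list with the exprs that | |||||
# actually consume it | |||||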
for expr in graph._exprs: | |||||
for i in expr.inputs: | |||||
if expr.inputs.count(i) != i.users.count(expr): | |||||
add_or_del_count = expr.inputs.count(i) - i.users.count(expr) | |||||
if add_or_del_count > 0: | |||||
i.users.extend([expr] * add_or_del_count) | |||||
else: | |||||
for _ in range(-add_or_del_count): | |||||
i.users.remove(expr) | |||||
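# rebind outputs whose recorded owner is a stale expr | |||||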
for o in expr.outputs: | |||||
if o.expr is not expr: | |||||
assert o not in o.expr.outputs | |||||
o.expr = expr | |||||
for node in graph.nodes(False): | |||||
# drop recorded users that no longer take this node as an input | |||||
node.users = [e for e in node.users if node in e.inputs] | |||||
for expr in graph._exprs: | |||||
graph._namespace.auto_naming_for_outputs(expr) | |||||
self._update_ref() | |||||
for _, graph in self.argdef_graph_map.items(): | |||||
graph._reset_ids() | |||||
def __copy__(self): | |||||
cls = self.__class__ | |||||
result = cls.__new__(cls) | |||||
result.__dict__.update(self.__dict__) | |||||
return result | |||||
def __deepcopy__(self, memo): | |||||
cls = self.__class__ | |||||
result = cls.__new__(cls) | |||||
state = {} | |||||
memo[id(self)] = result | |||||
for k, v in self.__dict__.items(): | |||||
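# skip weakrefs: they cannot be deep-copied and are rebuilt by | |||||
# result._update_ref() below | |||||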
if not isinstance(v, weakref.ReferenceType): | |||||
state[k] = copy.deepcopy(v, memo) | |||||
result.__dict__.update(state) | |||||
result._update_ref() | |||||
return result | |||||
def cpp_apply_module_trace(opdef, *args): | def cpp_apply_module_trace(opdef, *args): | ||||
return Apply.apply_module_trace_hook(opdef, *args) | return Apply.apply_module_trace_hook(opdef, *args) | ||||
USER_REGISTERED_MODULE = [] | |||||
USER_REGISTERED_FUNCTION = [] | |||||
def register_as_builtin(mod_cls: Type[Module]) -> None: | def register_as_builtin(mod_cls: Type[Module]) -> None: | ||||
r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. | r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. | ||||
Args: | Args: | ||||
mod_cls: the module class which will be treated as builtin module in tracing. | mod_cls: the module class which will be treated as builtin module in tracing. | ||||
""" | """ | ||||
USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__)) | |||||
module_tracer.register_as_builtin(mod_cls) | module_tracer.register_as_builtin(mod_cls) | ||||
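# A minimal usage sketch, assuming a user-defined module ``MyActivation`` | |||||
# (hypothetical): once registered, tracing keeps calls to it as a single | |||||
# builtin expr instead of tracing into its ``forward``. | |||||
# | |||||
#     import megengine.module as M | |||||
# | |||||
#     class MyActivation(M.Module): | |||||
#         def forward(self, x): | |||||
#             return x * 2 | |||||
# | |||||
#     register_as_builtin(MyActivation) | |||||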
@@ -2181,6 +2306,7 @@ def wrap(func: Callable): | |||||
Args: | Args: | ||||
func: the function of the global function to insert into the graph when it's called. | func: the function of the global function to insert into the graph when it's called. | ||||
""" | """ | ||||
USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__)) | |||||
assert callable(func), "func must be a callable" | assert callable(func), "func must be a callable" | ||||
assert hasattr(func, "__code__") | assert hasattr(func, "__code__") | ||||
fn_name = func.__code__.co_name | fn_name = func.__code__.co_name | ||||
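# A hedged usage sketch: ``my_scale`` is hypothetical; after registration, | |||||
# calls to it inside a traced ``forward`` are recorded as one CallFunction | |||||
# expr instead of being traced through. | |||||
# | |||||
#     def my_scale(x, factor=2.0): | |||||
#         return x * factor | |||||
# | |||||
#     wrap(my_scale) | |||||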
@@ -2247,6 +2373,8 @@ def trace_module( | |||||
NodeMixin.wrap_safe( | NodeMixin.wrap_safe( | ||||
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | ||||
) | ) | ||||
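# normalize kwargs into positional args so the flattened inputs follow | |||||
# mod.forward's signature deterministically | |||||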
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||||
inputs, _ = tree_flatten((args, kwargs)) | inputs, _ = tree_flatten((args, kwargs)) | ||||
for _, i in enumerate(inputs): | for _, i in enumerate(inputs): | ||||
# assert isinstance(i, Tensor), "not support " | # assert isinstance(i, Tensor), "not support " | ||||
@@ -5,12 +5,17 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | |||||
import copy | import copy | ||||
import inspect | |||||
from collections.abc import MutableMapping, MutableSequence | from collections.abc import MutableMapping, MutableSequence | ||||
from typing import Dict, Iterable, List, Optional, Sequence | |||||
from typing import Dict, Iterable, List, Optional, Sequence, Type | |||||
from .. import get_logger | |||||
from ..module import Module | from ..module import Module | ||||
logger = get_logger(__name__) | |||||
def replace_container_with_module_container(container): | def replace_container_with_module_container(container): | ||||
has_module = False | has_module = False | ||||
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container): | |||||
return has_module, module_container | return has_module, module_container | ||||
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||||
# is_bounded = True when func is a method and provided args don't include 'self' | |||||
arg_specs = inspect.getfullargspec(func) | |||||
arg_specs_args = arg_specs.args | |||||
if is_bounded: | |||||
arg_specs_args = arg_specs.args[1:] | |||||
new_args = [] | |||||
new_kwargs = {} | |||||
new_args.extend(args) | |||||
if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()): | |||||
repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) | |||||
raise TypeError( | |||||
"{} got multiple values for argument {}".format( | |||||
func.__qualname__, ", ".join(repeated_arg_name) | |||||
) | |||||
) | |||||
if len(new_args) < len(arg_specs.args): | |||||
for ind in range(len(new_args), len(arg_specs_args)): | |||||
arg_name = arg_specs_args[ind] | |||||
if arg_name in kwargs: | |||||
new_args.append(kwargs[arg_name]) | |||||
else: | |||||
index = ind - len(arg_specs_args) + len(arg_specs.defaults) | |||||
assert index < len(arg_specs.defaults) and index >= 0 | |||||
new_args.append(arg_specs.defaults[index]) | |||||
for kwarg_name in arg_specs.kwonlyargs: | |||||
if kwarg_name in kwargs: | |||||
new_kwargs[kwarg_name] = kwargs[kwarg_name] | |||||
else: | |||||
assert kwarg_name in arg_specs.kwonlydefaults | |||||
new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name] | |||||
for k, v in kwargs.items(): | |||||
if k not in arg_specs.args and k not in arg_specs.kwonlyargs: | |||||
if arg_specs.varkw is None: | |||||
raise TypeError( | |||||
"{} got an unexpected keyword argument {}".format( | |||||
func.__qualname__, k | |||||
) | |||||
) | |||||
new_kwargs[k] = v | |||||
return tuple(new_args), new_kwargs | |||||
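# A worked sketch of the normalization (``f`` is illustrative): | |||||
# | |||||
#     def f(a, b=2, *, c=3): | |||||
#         pass | |||||
# | |||||
#     _convert_kwargs_to_args(f, (1,), {"c": 5})  # -> ((1, 2), {"c": 5}) | |||||
#     _convert_kwargs_to_args(f, (), {"a": 1})    # -> ((1, 2), {"c": 3}) | |||||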
def _check_obj_attr(obj): | |||||
# check whether all the attributes of an obj are serializable | |||||
from .pytree import tree_flatten | |||||
from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef | |||||
from .expr import Expr | |||||
from .traced_module import TracedModule, InternalGraph, NameSpace | |||||
def _check_leaf_type(leaf): | |||||
leaf_type = leaf if isinstance(leaf, type) else type(leaf) | |||||
traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace] | |||||
return ( | |||||
issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types)) | |||||
or leaf_type in SUPPORTED_LEAF_TYPE | |||||
) | |||||
for _, v in obj.items(): | |||||
leafs, _ = tree_flatten(v, is_leaf=lambda _: True) | |||||
for leaf in leafs: | |||||
assert _check_leaf_type( | |||||
leaf | |||||
), "Type {} is not supported by traced module".format( | |||||
leaf if isinstance(leaf, type) else type(leaf) | |||||
) | |||||
def _check_builtin_module_attr(mod): | |||||
from .pytree import _is_leaf as _check_leaf_type | |||||
from .pytree import tree_flatten | |||||
# check whether all the attributes of a builtin module are serializable | |||||
is_non_serializable_module = lambda m: isinstance( | |||||
m, Module | |||||
) and not _check_builtin_module_attr(m) | |||||
for k, v in mod.__dict__.items(): | |||||
if k == "_m_dump_modulestate": | |||||
continue | |||||
if is_non_serializable_module(v): | |||||
return False | |||||
elif not isinstance(v, Module): | |||||
leafs, _ = tree_flatten(v, is_leaf=lambda _: True) | |||||
for leaf in leafs: | |||||
if not _check_leaf_type(leaf) or is_non_serializable_module(leaf): | |||||
logger.warning( | |||||
"Type {} is not supported by traced module".format( | |||||
leaf if isinstance(leaf, type) else type(leaf) | |||||
) | |||||
) | |||||
return False | |||||
return True | |||||
class _ModuleList(Module, MutableSequence): | class _ModuleList(Module, MutableSequence): | ||||
r"""A List-like container. | r"""A List-like container. | ||||
@@ -15,7 +15,6 @@ import numpy as np | |||||
import megengine as mge | import megengine as mge | ||||
from megengine import Parameter, Tensor | from megengine import Parameter, Tensor | ||||
from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state | |||||
def test_tensor_serialization(): | def test_tensor_serialization(): | ||||
@@ -88,25 +87,3 @@ def test_compatibility(): | |||||
test_old_tensor("tensor_v1_1.mge") | test_old_tensor("tensor_v1_1.mge") | ||||
test_old_tensor("tensor_v1_2.mge") | test_old_tensor("tensor_v1_2.mge") | ||||
def test_opdef_serialization(): | |||||
with TemporaryFile() as f: | |||||
x = builtin.Elemwise(mode="Add") | |||||
pickle.dump(get_opdef_state(x), f) | |||||
f.seek(0) | |||||
load_x = load_opdef_from_state(pickle.load(f)) | |||||
assert x == load_x | |||||
with TemporaryFile() as f: | |||||
x = builtin.Convolution(stride_h=9, compute_mode="float32") | |||||
x.strategy = ( | |||||
builtin.Convolution.Strategy.PROFILE | |||||
| builtin.Convolution.Strategy.HEURISTIC | |||||
| builtin.Convolution.Strategy.REPRODUCIBLE | |||||
) | |||||
pickle.dump(get_opdef_state(x), f) | |||||
f.seek(0) | |||||
load_x = load_opdef_from_state(pickle.load(f)) | |||||
assert x.strategy == load_x.strategy | |||||
assert x == load_x |
@@ -85,12 +85,12 @@ class NewModule(M.Module): | |||||
return x | return x | ||||
def _check_expr_users(traced_module): | |||||
def _check_expr_users(flattened_module): | |||||
node_user = defaultdict(list) | node_user = defaultdict(list) | ||||
for expr in traced_module.graph._exprs: | |||||
for expr in flattened_module.graph._exprs: | |||||
for node in expr.inputs: | for node in expr.inputs: | ||||
node_user[node].append(expr) | node_user[node].append(expr) | ||||
for node in traced_module.graph.nodes(): | |||||
for node in flattened_module.graph.nodes(): | |||||
node.users.sort(key=lambda m: m._id) | node.users.sort(key=lambda m: m._id) | ||||
node_user[node].sort(key=lambda m: m._id) | node_user[node].sort(key=lambda m: m._id) | ||||
assert node.users == node_user[node] | assert node.users == node_user[node] | ||||
@@ -8,6 +8,7 @@ import numpy as np | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
import megengine.module.qat as QM | |||||
import megengine.quantization as Q | import megengine.quantization as Q | ||||
from megengine import Tensor | from megengine import Tensor | ||||
from megengine.module.qat.module import QATModule | from megengine.module.qat.module import QATModule | ||||
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str): | |||||
return getattr(self, name) | return getattr(self, name) | ||||
class MyConvBnRelu2d(M.ConvBnRelu2d): | |||||
pass | |||||
class MyQATConvBnRelu2d(QM.ConvBnRelu2d): | |||||
pass | |||||
class Myblcok(M.Module): | class Myblcok(M.Module): | ||||
def __init__(self,): | def __init__(self,): | ||||
super().__init__() | super().__init__() | ||||
self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) | |||||
self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1) | |||||
self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) | self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) | ||||
self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) | self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) | ||||
self.add = M.Elemwise("FUSE_ADD_RELU") | self.add = M.Elemwise("FUSE_ADD_RELU") | ||||
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): | |||||
def build_observered_net(net: M.Module, observer_cls): | def build_observered_net(net: M.Module, observer_cls): | ||||
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) | |||||
qat_net = Q.quantize_qat( | |||||
net, | |||||
qconfig=get_observer_config(observer_cls), | |||||
mapping={MyConvBnRelu2d: MyQATConvBnRelu2d}, | |||||
) | |||||
Q.enable_observer(qat_net) | Q.enable_observer(qat_net) | ||||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | ||||
qat_net(inp) | qat_net(inp) | ||||
@@ -134,6 +147,15 @@ def test_trace_qat(): | |||||
check_qparams(weight_qparams, traced_weight_qparams) | check_qparams(weight_qparams, traced_weight_qparams) | ||||
if act_qparams: | if act_qparams: | ||||
check_qparams(act_qparams, traced_act_qparams) | check_qparams(act_qparams, traced_act_qparams) | ||||
flatten_traced_net = traced_net.flatten() | |||||
conv0_node = flatten_traced_net.graph.get_node_by_name( | |||||
"MyModule_block0_conv0" | |||||
).as_unique() | |||||
conv0_out_node = flatten_traced_net.graph.get_node_by_name( | |||||
"MyModule_block0_conv0_out" | |||||
).as_unique() | |||||
assert isinstance(conv0_node.owner, TracedModule) | |||||
assert conv0_out_node.expr.inputs[0] is conv0_node | |||||
_check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) | _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) | ||||
_check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) | _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) | ||||
@@ -6,14 +6,59 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import pickle | import pickle | ||||
from collections import defaultdict | |||||
from tempfile import TemporaryFile | |||||
import numpy as np | import numpy as np | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
import megengine.traced_module.serialization as S | |||||
from megengine import Tensor | from megengine import Tensor | ||||
from megengine.core._imperative_rt.core2 import apply | |||||
from megengine.core.ops import builtin | |||||
from megengine.core.ops.builtin import Elemwise | |||||
from megengine.module import Module | from megengine.module import Module | ||||
from megengine.traced_module import trace_module | from megengine.traced_module import trace_module | ||||
from megengine.traced_module.expr import CallMethod, Constant | |||||
from megengine.traced_module.node import TensorNode | |||||
from megengine.traced_module.serialization import ( | |||||
register_functional_loader, | |||||
register_module_loader, | |||||
register_opdef_loader, | |||||
register_tensor_method_loader, | |||||
) | |||||
from megengine.traced_module.utils import _convert_kwargs_to_args | |||||
def _check_id(traced_module): | |||||
_total_ids = traced_module.graph._total_ids | |||||
node_ids = [n._id for n in traced_module.graph.nodes().as_list()] | |||||
assert len(set(node_ids)) == len(node_ids) | |||||
assert max(node_ids) + 1 == _total_ids[0] | |||||
expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] | |||||
assert len(set(expr_ids)) == len(expr_ids) | |||||
assert max(expr_ids) + 1 == _total_ids[1] | |||||
def _check_name(flattened_module): | |||||
node_names = [n._name for n in flattened_module.graph.nodes().as_list()] | |||||
assert len(set(node_names)) == len(node_names) | |||||
def _check_expr_users(traced_module): | |||||
node_user = defaultdict(list) | |||||
for expr in traced_module.graph._exprs: | |||||
for node in expr.inputs: | |||||
node_user[node].append(expr) | |||||
if isinstance(expr, CallMethod) and expr.graph: | |||||
_check_expr_users(expr.inputs[0].owner) | |||||
for node in traced_module.graph.nodes(False): | |||||
node.users.sort(key=lambda m: m._id) | |||||
node_user[node].sort(key=lambda m: m._id) | |||||
assert node.users == node_user[node] | |||||
class MyBlock(Module): | class MyBlock(Module): | ||||
@@ -48,5 +93,274 @@ def test_dump_and_load(): | |||||
traced_module = trace_module(module, x) | traced_module = trace_module(module, x) | ||||
np.testing.assert_array_equal(expect, traced_module(x)) | np.testing.assert_array_equal(expect, traced_module(x)) | ||||
obj = pickle.dumps(traced_module) | obj = pickle.dumps(traced_module) | ||||
pickle.loads(obj) | |||||
new_tm = pickle.loads(obj) | |||||
_check_id(new_tm) | |||||
_check_expr_users(new_tm) | |||||
traced_module.graph._reset_ids() | |||||
old_nodes = traced_module.graph.nodes().as_list() | |||||
new_nodes = new_tm.graph.nodes().as_list() | |||||
old_exprs = traced_module.graph.exprs().as_list() | |||||
new_exprs = new_tm.graph.exprs().as_list() | |||||
assert len(old_nodes) == len(new_nodes) | |||||
for i, j in zip(old_nodes, new_nodes): | |||||
assert i._name == j._name | |||||
assert i._qualname == j._qualname | |||||
assert i._id == j._id | |||||
assert len(old_exprs) == len(new_exprs) | |||||
for i, j in zip(old_exprs, new_exprs): | |||||
assert i._id == j._id | |||||
np.testing.assert_array_equal(expect, traced_module(x)) | np.testing.assert_array_equal(expect, traced_module(x)) | ||||
def test_opdef_loader(): | |||||
class MyModule1(Module): | |||||
def forward(self, x, y): | |||||
op = Elemwise("ADD") | |||||
return apply(op, x, y)[0] | |||||
m = MyModule1() | |||||
x = Tensor(np.ones((20))) | |||||
y = Tensor(np.ones((20))) | |||||
traced_module = trace_module(m, x, y) | |||||
orig_loader_dict = S.OPDEF_LOADER | |||||
S.OPDEF_LOADER = {} | |||||
@register_opdef_loader(Elemwise) | |||||
def add_opdef_loader(expr): | |||||
if expr.opdef_state["mode"] == "ADD": | |||||
expr.opdef_state["mode"] = "MUL" | |||||
node = expr.inputs[1] | |||||
astype_expr = CallMethod(node, "astype") | |||||
oup = TensorNode( | |||||
astype_expr, | |||||
shape=node.shape, | |||||
dtype=expr.inputs[0].dtype, | |||||
qparams=node.qparams, | |||||
) | |||||
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||||
astype_expr.return_val = (oup,) | |||||
expr.inputs[1] = oup | |||||
obj = pickle.dumps(traced_module) | |||||
new_module = pickle.loads(obj) | |||||
_check_id(new_module) | |||||
_check_expr_users(new_module) | |||||
_check_name(new_module.flatten()) | |||||
assert ( | |||||
isinstance(new_module.graph._exprs[0], CallMethod) | |||||
and new_module.graph._exprs[1].opdef.mode == "MUL" | |||||
and len(new_module.graph._exprs) == 2 | |||||
) | |||||
result = new_module(x, y) | |||||
np.testing.assert_equal(result.numpy(), x.numpy()) | |||||
S.OPDEF_LOADER = orig_loader_dict | |||||
def test_functional_loader(): | |||||
class MyModule2(Module): | |||||
def forward(self, x, y): | |||||
return F.conv2d(x, y) | |||||
m = MyModule2() | |||||
x = Tensor(np.random.random((1, 3, 32, 32))) | |||||
y = Tensor(np.random.random((3, 3, 3, 3))) | |||||
traced_module = trace_module(m, x, y) | |||||
orig_loader_dict = S.FUNCTIONAL_LOADER | |||||
S.FUNCTIONAL_LOADER = {} | |||||
@register_functional_loader(("megengine.functional.nn", "conv2d")) | |||||
def conv2df_loader(expr): | |||||
# expr.func = ("megengine.functional.nn","conv2d") | |||||
kwargs = expr.kwargs | |||||
orig_weight = expr.named_args["weight"] | |||||
astype_expr = CallMethod(orig_weight, "astype") | |||||
oup = TensorNode( | |||||
astype_expr, | |||||
shape=orig_weight.shape, | |||||
dtype=orig_weight.dtype, | |||||
qparams=orig_weight.qparams, | |||||
) | |||||
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype) | |||||
astype_expr.return_val = (oup,) | |||||
expr.set_arg("weight", oup) | |||||
obj = pickle.dumps(traced_module) | |||||
new_module = pickle.loads(obj) | |||||
_check_expr_users(new_module) | |||||
_check_id(new_module) | |||||
result = new_module(x, y) | |||||
gt = m(x, y) | |||||
assert ( | |||||
isinstance(new_module.graph._exprs[0], CallMethod) | |||||
and len(new_module.graph._exprs) == 2 | |||||
) | |||||
np.testing.assert_equal(result.numpy(), gt.numpy()) | |||||
S.FUNCTIONAL_LOADER = orig_loader_dict | |||||
def test_tensor_method_loader(): | |||||
class MyModule3(Module): | |||||
def forward(self, x): | |||||
return x + 1 | |||||
m = MyModule3() | |||||
x = Tensor(np.ones((20))) | |||||
traced_module = trace_module(m, x) | |||||
orig_loader_dict = S.TENSORMETHOD_LOADER | |||||
S.TENSORMETHOD_LOADER = {} | |||||
@register_tensor_method_loader("__add__") | |||||
def add_loader(expr): | |||||
args = list(expr.args) | |||||
if not isinstance(args[1], TensorNode): | |||||
args[1] = Tensor(args[1]) | |||||
node = Constant(args[1], "const").outputs[0] | |||||
astype_expr = CallMethod(node, "astype") | |||||
oup = TensorNode( | |||||
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams, | |||||
) | |||||
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||||
astype_expr.return_val = (oup,) | |||||
add_expr = CallMethod(oup, "__add__") | |||||
add_expr.set_args_kwargs(oup, oup) | |||||
oup1 = TensorNode( | |||||
add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams, | |||||
) | |||||
add_expr.return_val = oup1 | |||||
args[1] = oup1 | |||||
expr.set_args_kwargs(*args) | |||||
obj = pickle.dumps(traced_module) | |||||
new_module = pickle.loads(obj) | |||||
_check_expr_users(new_module) | |||||
_check_id(new_module) | |||||
result = new_module(x) | |||||
gt = m(x) | |||||
assert ( | |||||
isinstance(new_module.graph._exprs[0], Constant) | |||||
and len(new_module.graph._exprs) == 4 | |||||
) | |||||
np.testing.assert_equal(result.numpy(), (x + 2).numpy()) | |||||
S.TENSORMETHOD_LOADER = orig_loader_dict | |||||
def test_module_loader(): | |||||
class MyModule4(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv = M.Conv2d(3, 3, 3) | |||||
def forward(self, x): | |||||
return self.conv(x) | |||||
m = MyModule4() | |||||
x = Tensor(np.random.random((1, 3, 32, 32))) | |||||
traced_module = trace_module(m, x) | |||||
orig_loader_dict = S.MODULE_LOADER | |||||
S.MODULE_LOADER = {} | |||||
@register_module_loader(("megengine.module.conv", "Conv2d")) | |||||
def conv2dm_loader(expr): | |||||
module = expr.inputs[0].owner | |||||
args = list(expr.args) | |||||
orig_inp = args[1] | |||||
astype_expr = CallMethod(orig_inp, "astype") | |||||
oup = TensorNode( | |||||
astype_expr, | |||||
shape=orig_inp.shape, | |||||
dtype=orig_inp.dtype, | |||||
qparams=orig_inp.qparams, | |||||
) | |||||
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype) | |||||
astype_expr.return_val = (oup,) | |||||
args[1] = oup | |||||
expr.set_args_kwargs(*args) | |||||
obj = pickle.dumps(traced_module) | |||||
new_module = pickle.loads(obj) | |||||
result = new_module(x) | |||||
gt = m(x) | |||||
assert ( | |||||
isinstance(new_module.graph._exprs[1], CallMethod) | |||||
and len(new_module.graph._exprs) == 3 | |||||
) | |||||
np.testing.assert_equal(result.numpy(), gt.numpy()) | |||||
S.MODULE_LOADER = orig_loader_dict | |||||
def test_shared_module(): | |||||
class MyModule(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.a = M.Elemwise("ADD") | |||||
self.b = self.a | |||||
def forward(self, x, y): | |||||
z = self.a(x, y) | |||||
z = self.b(z, y) | |||||
return z | |||||
x = Tensor(1) | |||||
y = Tensor(2) | |||||
m = MyModule() | |||||
tm = trace_module(m, x, y) | |||||
obj = pickle.dumps(tm) | |||||
load_tm = pickle.loads(obj) | |||||
_check_expr_users(load_tm) | |||||
_check_name(load_tm.flatten()) | |||||
_check_id(load_tm) | |||||
assert load_tm.a is load_tm.b | |||||
def test_convert_kwargs_to_args(): | |||||
def func(a, b, c=4, *, d, e=3, f=4): | |||||
pass | |||||
args = (1,) | |||||
kwargs = {"b": 1, "d": 6} | |||||
new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs) | |||||
assert new_args == (1, 1, 4) | |||||
assert new_kwargs == {"d": 6, "e": 3, "f": 4} | |||||
args = (1,) | |||||
kwargs = {"d": 6} | |||||
new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True) | |||||
assert new_args == (1, 4) | |||||
assert new_kwargs == {"d": 6, "e": 3, "f": 4} | |||||
def func1(a, b, c, d, e, *, f): | |||||
pass | |||||
args = () | |||||
kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6} | |||||
new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs) | |||||
assert new_args == (1, 2, 3, 4, 5) | |||||
assert new_kwargs == {"f": 6} | |||||
def test_opdef_serialization(): | |||||
with TemporaryFile() as f: | |||||
x = builtin.Elemwise(mode="Add") | |||||
pickle.dump(x, f) | |||||
f.seek(0) | |||||
load_x = pickle.load(f) | |||||
assert x == load_x | |||||
with TemporaryFile() as f: | |||||
x = builtin.Convolution(stride_h=9, compute_mode="float32") | |||||
x.strategy = ( | |||||
builtin.Convolution.Strategy.PROFILE | |||||
| builtin.Convolution.Strategy.HEURISTIC | |||||
| builtin.Convolution.Strategy.REPRODUCIBLE | |||||
) | |||||
pickle.dump(x, f) | |||||
f.seek(0) | |||||
load_x = pickle.load(f) | |||||
assert x.strategy == load_x.strategy | |||||
assert x == load_x |