GitOrigin-RevId: aaa9e51c74
release-1.7
@@ -7,6 +7,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
from . import compat | |||
from .traced_module import ( | |||
TracedModule, | |||
_register_all_builtin_module, | |||
@@ -0,0 +1,136 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from .. import tensor | |||
from ..core.ops.builtin import BatchNorm | |||
from .expr import CallMethod, Constant | |||
from .node import TensorNode | |||
from .serialization import ( | |||
register_functional_loader, | |||
register_module_loader, | |||
register_opdef_loader, | |||
register_tensor_method_loader, | |||
) | |||
""" | |||
# Expr loaders examples | |||
from ..core.ops.builtin import Elemwise | |||
@register_opdef_loader(Elemwise) | |||
def add_opdef_loader(expr): | |||
if expr.opdef_state["mode"] == "ADD": | |||
expr.opdef_state["mode"] == "MUL" | |||
node = expr.inputs[1] | |||
astype_expr = CallMethod(node, "astype") | |||
oup = TensorNode( | |||
astype_expr, | |||
shape=node.shape, | |||
dtype=expr.inputs[0].dtype, | |||
qparams=node.qparams, | |||
) | |||
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||
astype_expr.return_val = (oup,) | |||
expr.inputs[1] = oup | |||
@register_functional_loader(("megengine.functional.nn", "conv2d")) | |||
def conv2df_loader(expr): | |||
# expr.func = ("megengine.functional.nn","conv2d") | |||
kwargs = expr.kwargs | |||
orig_weight = expr.named_args["weight"] | |||
astype_expr = CallMethod(orig_weight, "astype") | |||
oup = TensorNode( | |||
astype_expr, | |||
shape=orig_weight.shape, | |||
dtype=orig_weight.dtype, | |||
qparams=orig_weight.qparams, | |||
) | |||
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype) | |||
astype_expr.return_val = (oup,) | |||
expr.set_arg("weight", oup) | |||
@register_module_loader(("megengine.module.conv", "Conv2d")) | |||
def conv2dm_loader(expr): | |||
module = expr.inputs[0].owner | |||
args = list(expr.args) | |||
orig_inp = args[1] | |||
astype_expr = CallMethod(orig_inp, "astype") | |||
oup = TensorNode( | |||
astype_expr, | |||
shape=orig_inp.shape, | |||
dtype=orig_inp.dtype, | |||
qparams=orig_inp.qparams, | |||
) | |||
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype) | |||
astype_expr.return_val = (oup,) | |||
args[1] = oup | |||
expr.set_args_kwargs(*args) | |||
@register_tensor_method_loader("__add__") | |||
def add_loader(expr): | |||
args = list(expr.args) | |||
if not isinstance(args[1], TensorNode): | |||
args[1] = tensor(args[1]) | |||
node = Constant(args[1], "const").outputs[0] | |||
astype_expr = CallMethod(node, "astype") | |||
oup = TensorNode( | |||
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams, | |||
) | |||
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||
astype_expr.return_val = (oup,) | |||
args[1] = oup | |||
expr.set_args_kwargs(*args) | |||
""" | |||
@register_module_loader( | |||
("megengine.module.batchnorm", "BatchNorm1d"), | |||
("megengine.module.batchnorm", "BatchNorm2d"), | |||
("megengine.module.batchnorm", "SyncBatchNorm"), | |||
) | |||
def bn2d_module_loader(expr): | |||
# compat: exprs dumped by mge 1.6 carry no "version" attribute | |||
if not hasattr(expr, "version"): | |||
module = expr.inputs[0].owner | |||
if not hasattr(module, "param_dim"): | |||
module.param_dim = "dim_1c11" | |||
@register_module_loader( | |||
("megengine.module.conv_bn", "ConvBn2d"), | |||
("megengine.module.conv_bn", "ConvBnRelu2d"), | |||
("megengine.module.qat.conv_bn", "ConvBn2d"), | |||
("megengine.module.qat.conv_bn", "ConvBnRelu2d"), | |||
) | |||
def convbn2d_module_loader(expr): | |||
# compat: exprs dumped by mge 1.6 carry no "version" attribute | |||
if not hasattr(expr, "version"): | |||
module = expr.inputs[0].owner | |||
if not hasattr(module.bn, "param_dim"): | |||
module.bn.param_dim = "dim_1c11" | |||
@register_opdef_loader(BatchNorm) | |||
def bn_opdef_loader(expr): | |||
# compat: exprs dumped by mge 1.6 carry no "version" attribute | |||
if not hasattr(expr, "version"): | |||
output = expr.outputs[-1] | |||
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) | |||
expr.outputs.insert(4, oup) |
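# Editor's note: the three loaders above are compatibility shims that run
# while unpickling a TracedModule dumped by an older MegEngine. A minimal
# sketch (the file name is hypothetical):
#
#     import pickle
#     with open("model_mge16.tm", "rb") as f:
#         tm = pickle.load(f)  # bn2d_module_loader / bn_opdef_loader fire here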
@@ -11,19 +11,28 @@ import collections | |||
import copy | |||
import inspect | |||
import re | |||
from typing import Callable, Dict, List, Optional, Union | |||
import weakref | |||
from importlib import import_module | |||
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||
from ..core._imperative_rt import OpDef | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
from ..core._imperative_rt.core2 import ( | |||
apply, | |||
is_tracing_module, | |||
set_module_tracing, | |||
unset_module_tracing, | |||
) | |||
from ..core.ops.builtin import FakeQuant | |||
from ..core.ops.special import Const | |||
from ..module import Module | |||
from ..tensor import Parameter, Tensor | |||
from ..version import __version__ | |||
from .module_tracer import active_module_tracer, module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | |||
from .serialization import get_opdef_state, load_opdef_from_state | |||
from .serialization import _ModuleState | |||
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | |||
def rstrip(s: str, __chars: str): | |||
@@ -112,6 +121,7 @@ class Expr: | |||
node.users.append(self) | |||
else: | |||
assert node is None | |||
assert not isinstance(val, (Module, RawTensor)) | |||
assert _is_leaf(val) and _is_const_leaf(val) | |||
idx = len(self.inputs) + len(self.const_val) | |||
self.const_val.append((idx, val)) | |||
@@ -132,14 +142,14 @@ class Expr: | |||
current_graph._namespace.auto_naming_for_outputs(self) | |||
def unflatten_args(self, inputs): | |||
if self.arg_def is not None: | |||
inputs = list(inputs) | |||
for idx, val in self.const_val: | |||
inputs.insert(idx, val) | |||
args, kwargs = self.arg_def.unflatten(inputs) | |||
return args, kwargs | |||
else: | |||
return inputs, {} | |||
assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format( | |||
type(self).__name__ | |||
) | |||
inputs = list(inputs) | |||
for idx, val in self.const_val: | |||
inputs.insert(idx, val) | |||
args, kwargs = self.arg_def.unflatten(inputs) | |||
return args, kwargs | |||
def replace_inputs(self, repl_dict: Dict[Node, Node]): | |||
r"""Replace the input Nodes of this Expr. | |||
@@ -166,6 +176,39 @@ class Expr: | |||
repl_node.users.append(self) | |||
@property | |||
def _support_set_args_kwargs(self): | |||
return False | |||
def set_args_kwargs(self, *args, **kwargs): | |||
r""" Set args and kwargs for Expr. | |||
""" | |||
assert ( | |||
self._support_set_args_kwargs | |||
), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__) | |||
args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs) | |||
inputs, arg_def = tree_flatten((args, kwargs)) | |||
orig_inputs = self.inputs | |||
self.inputs = [] | |||
self.const_val = [] | |||
for val in inputs: | |||
if isinstance(val, (TensorNode, ModuleNode)): | |||
self.inputs.append(val) | |||
else: | |||
assert _is_leaf(val) and _is_const_leaf(val) | |||
idx = len(self.inputs) + len(self.const_val) | |||
self.const_val.append((idx, val)) | |||
for n in orig_inputs: | |||
if n not in self.inputs: | |||
n.users.remove(self) | |||
for n in self.inputs: | |||
if n not in orig_inputs: | |||
n.users.append(self) | |||
self.arg_def = arg_def | |||
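# Usage sketch (editor's illustration, mirroring the loader examples above):
# rebuild the args of an "astype" CallMethod so that its input node is cast
# to the dtype of the expr's first input:
#
#     astype_expr = CallMethod(node, "astype")
#     astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)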
@property | |||
def kwargs(self): | |||
r"""Get the keyword arguments of the operation corresponding to this Expr.""" | |||
_, kwargs = self.unflatten_args(self.inputs) | |||
@@ -177,6 +220,61 @@ class Expr: | |||
args, _ = self.unflatten_args(self.inputs) | |||
return args | |||
def _get_func(self): | |||
# get the function that will be called when the expr is interpreted | |||
raise NotImplementedError | |||
@property | |||
def named_args(self): | |||
func = self._get_func() | |||
return inspect.getcallargs(func, *self.args, **self.kwargs) | |||
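# Sketch: for a conv2d CallFunction, named_args binds the flattened inputs
# back to parameter names, e.g. {"inp": <TensorNode>, "weight": <TensorNode>,
# ...}, which is what the conv2df_loader example above relies on.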
def set_arg(self, name, val): | |||
func = self._get_func() | |||
if name in self.kwargs: | |||
new_kwargs = self.kwargs | |||
new_kwargs[name] = val | |||
self.set_args_kwargs(*self.args, **new_kwargs) | |||
else: | |||
arg_spec = inspect.getfullargspec(func) | |||
if name in arg_spec.args: | |||
ind = arg_spec.args.index(name) | |||
new_args = list(self.args) | |||
new_args[ind] = val | |||
self.set_args_kwargs(*new_args) | |||
elif name == arg_spec.varargs: | |||
assert arg_spec.varargs is not None | |||
assert len(self.args) >= len(arg_spec.args) | |||
val = (val,) if not isinstance(val, Sequence) else val | |||
self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val) | |||
else: | |||
assert ( | |||
arg_spec.varkw is not None | |||
), "func {} does't have argument named {}".format(func, name) | |||
new_kwargs = self.kwargs | |||
new_kwargs[name] = val | |||
self.set_args_kwargs(*self.args, **new_kwargs) | |||
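# Usage sketch (editor's illustration): replace one named argument without
# touching the rest, as the conv2df_loader example above does:
#
#     expr.set_arg("weight", oup)  # oup is a TensorNode produced by astype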
@property | |||
def return_val(self): | |||
return self.out_def.unflatten(self.outputs) | |||
@return_val.setter | |||
def return_val(self, new_outputs): | |||
outputs, out_def = tree_flatten( | |||
new_outputs, is_leaf=lambda x: isinstance(x, Node) | |||
) | |||
assert all( | |||
isinstance(o, Node) for o in outputs | |||
), "Return values of expr must be ModuleNode or TensorNode or Container with them" | |||
assert all( | |||
o.expr in (None, self) for o in outputs | |||
), "Some nodes are produced by other expr, can not be output of expr {}".format( | |||
self | |||
) | |||
self.outputs = outputs | |||
self.out_def = out_def | |||
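# Usage sketch: bind a freshly created output node to a new expr, as the
# loader examples above do:
#
#     astype_expr.return_val = (oup,)  # oup.expr must be None or astype_expr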
@property | |||
def top_graph(self): | |||
r"""Get the parent graph of this Expr.""" | |||
@@ -184,12 +282,6 @@ class Expr: | |||
return self._top_graph() | |||
return None | |||
def __getstate__(self): | |||
state = self.__dict__.copy() | |||
if "_top_graph" in state: | |||
state.pop("_top_graph") | |||
return state | |||
@classmethod | |||
def _get_next_id(cls): | |||
return cls.__total_id | |||
@@ -199,6 +291,23 @@ class Expr: | |||
assert isinstance(id, int) | |||
cls.__total_id = id | |||
def __copy__(self): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
result.__dict__.update(self.__dict__) | |||
return result | |||
def __deepcopy__(self, memo): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
state = {} | |||
memo[id(self)] = result | |||
for k, v in self.__dict__.items(): | |||
if not isinstance(v, weakref.ReferenceType): | |||
state[k] = copy.deepcopy(v, memo) | |||
result.__dict__.update(state) | |||
return result | |||
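# Note (editor's sketch): weakref attributes such as _top_graph are skipped
# on purpose when deep-copying; they are presumably re-established once the
# copy is attached to a graph again (cf. _update_ref further below).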
# expr: None (i.e. a fake expression used to mark graph inputs) | |||
class Input(Expr): | |||
@@ -229,6 +338,17 @@ class Input(Expr): | |||
def __repr__(self): | |||
return "%{}:\t{} = Input()".format(self._id, self.outputs[0]) | |||
def __getstate__(self): | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"inputs": self.inputs, | |||
"outputs": self.outputs, | |||
"name": self.name, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
# expr: outputs = getattr(inputs[0], self.name) | |||
class GetAttr(Expr): | |||
@@ -276,11 +396,23 @@ class GetAttr(Expr): | |||
def __repr__(self): | |||
out_type = "Tensor" | |||
if isinstance(self.outputs[0], ModuleNode): | |||
out_type = self.outputs[0].module_type.__name__ | |||
m_type = self.outputs[0].module_type | |||
out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1] | |||
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format( | |||
self._id, self.outputs[0], self.inputs[0], self.name, out_type | |||
) | |||
def __getstate__(self): | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"inputs": self.inputs, | |||
"outputs": self.outputs, | |||
"name": self.name, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
# expr: outputs = inputs[0].__call__(*inputs[1:]) | |||
class CallMethod(Expr): | |||
@@ -307,6 +439,7 @@ class CallMethod(Expr): | |||
node, | |||
] | |||
self.const_val = [] | |||
self.arg_def = tree_flatten(((node,), {}))[1] | |||
self.method = method | |||
@classmethod | |||
@@ -342,6 +475,27 @@ class CallMethod(Expr): | |||
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) | |||
return outputs | |||
def _get_func(self): | |||
if isinstance(self.args[0], type): | |||
obj_type = self.args[0] | |||
elif isinstance(self.args[0], ModuleNode): | |||
obj_type = self.args[0].module_type | |||
else: | |||
assert isinstance(self.args[0], TensorNode) | |||
obj_type = Tensor | |||
meth = getattr( | |||
obj_type, "forward" if issubclass(obj_type, Module) else self.method | |||
) | |||
return meth | |||
@property | |||
def _support_set_args_kwargs(self): | |||
# only exprs that call a tensor method or a builtin module support modifying args/kwargs | |||
return ( | |||
isinstance(self.args[0], (TensorNode, type)) | |||
or self.args[0].module_type is not Module | |||
) | |||
def __repr__(self): | |||
args = ", ".join(str(i) for i in self.args[1:]) | |||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
@@ -359,6 +513,21 @@ class CallMethod(Expr): | |||
", ".join([args, kwargs]), | |||
) | |||
def __getstate__(self): | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"inputs": self.inputs, | |||
"const_val": self.const_val, | |||
"method": self.method, | |||
"arg_def": self.arg_def, | |||
"out_def": self.out_def, | |||
"outputs": self.outputs, | |||
"version": __version__, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
# expr: outputs = apply(self.opdef, *inputs) | |||
class Apply(Expr): | |||
@@ -394,14 +563,32 @@ class Apply(Expr): | |||
) | |||
def __getstate__(self): | |||
state = super().__getstate__() | |||
state["opdef"] = get_opdef_state(state["opdef"]) | |||
opdef_state = self.opdef.__getstate__() | |||
opdef_state["opdef_type"] = type(self.opdef) | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"opdef_state": opdef_state, | |||
"inputs": self.inputs, | |||
"outputs": self.outputs, | |||
"version": __version__, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
def __setstate__(self, state): | |||
state["opdef"] = load_opdef_from_state(state["opdef"]) | |||
for k, v in state.items(): | |||
setattr(self, k, v) | |||
# compat with mge 1.6 | |||
if "opdef" in state and "opdef_state" not in state: | |||
opdef_state = state.pop("opdef") | |||
opdef_state["opdef_type"] = opdef_state.pop("type") | |||
state["opdef_state"] = opdef_state | |||
self.__dict__.update(state) | |||
assert isinstance(state["opdef_state"], dict) | |||
opdef_state = state["opdef_state"].copy() | |||
opdef_type = opdef_state.pop("opdef_type") | |||
opdef_obj = opdef_type() | |||
opdef_obj.__setstate__(opdef_state) | |||
setattr(self, "opdef", opdef_obj) | |||
@classmethod | |||
def apply_module_trace_hook(cls, opdef, *inputs): | |||
@@ -458,12 +645,24 @@ class CallFunction(Expr): | |||
def interpret(self, *inputs): | |||
args, kwargs = self.unflatten_args(inputs) | |||
outputs = self.func(*args, **kwargs) | |||
func = ( | |||
self.func | |||
if not is_tracing_module() | |||
else active_module_tracer().patcher.wrap_fn(self.func) | |||
) | |||
outputs = func(*args, **kwargs) | |||
if outputs is None: | |||
return outputs | |||
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) | |||
return outputs | |||
def _get_func(self): | |||
return self.func | |||
@property | |||
def _support_set_args_kwargs(self): | |||
return True | |||
def __repr__(self): | |||
args = ", ".join(str(i) for i in self.args) | |||
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
@@ -477,6 +676,33 @@ class CallFunction(Expr): | |||
", ".join([args, kwargs]), | |||
) | |||
def __getstate__(self): | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"func": (self.func.__module__, self.func.__qualname__), | |||
"const_val": self.const_val, | |||
"inputs": self.inputs, | |||
"arg_def": self.arg_def, | |||
"out_def": self.out_def, | |||
"outputs": self.outputs, | |||
"version": __version__, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
def __setstate__(self, state): | |||
self.__dict__.update(state) | |||
try: | |||
if isinstance(self.func, tuple): | |||
mname, fname = self.func | |||
f = import_module(mname) | |||
for i in fname.split("."): | |||
f = getattr(f, i) | |||
self.func = f | |||
except Exception: | |||
pass | |||
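# Sketch of the restore path above: a func pickled as the tuple
# ("megengine.functional.nn", "conv2d") is resolved back to the callable by
# import_module("megengine.functional.nn") followed by getattr walks over
# each component of the qualified name.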
# expr: outputs = self.value | |||
class Constant(Expr): | |||
@@ -496,6 +722,13 @@ class Constant(Expr): | |||
assert isinstance(c, (RawTensor, Module)) | |||
if isinstance(c, Module): | |||
assert module_tracer.is_builtin(c) or c.is_qat | |||
if isinstance(c, RawTensor): | |||
if is_tracing_module(): | |||
unset_module_tracing() | |||
c = Tensor(c) | |||
set_module_tracing() | |||
else: | |||
c = Tensor(c) | |||
self.value = c | |||
self.name = name | |||
self.inputs = [] | |||
@@ -530,9 +763,25 @@ class Constant(Expr): | |||
) | |||
def __getstate__(self): | |||
state = self.__dict__.copy() | |||
if "_top_graph" in state: | |||
state.pop("_top_graph") | |||
state = { | |||
"_id": self._id, | |||
"_disable_remove": self._disable_remove, | |||
"value": self.value, | |||
"name": self.name, | |||
"inputs": self.inputs, | |||
"outputs": self.outputs, | |||
} | |||
_check_obj_attr(state) | |||
if isinstance(self.value, RawTensor): | |||
state["value"] = Tensor(self.value) | |||
if isinstance(self.value, Module) and module_tracer.is_builtin(self.value): | |||
_check_builtin_module_attr(self.value) | |||
state["value"] = _ModuleState.get_module_state(self.value) | |||
return state | |||
def __setstate__(self, state): | |||
for k, v in state.items(): | |||
if isinstance(v, _ModuleState): | |||
state[k] = v.to_module() | |||
self.__dict__.update(state) |
@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [ | |||
"astype", | |||
"reshape", | |||
"_broadcast", | |||
"transpose", | |||
"flatten", | |||
"sum", | |||
"prod", | |||
@@ -6,7 +6,9 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import abc | |||
import copy | |||
import weakref | |||
from importlib import import_module | |||
from typing import Any, Dict, List, Tuple, Type | |||
import numpy | |||
@@ -14,7 +16,9 @@ import numpy | |||
from .. import get_logger | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..module import Module | |||
from ..quantization.utils import QParams | |||
from ..tensor import Tensor | |||
from .utils import _check_obj_attr | |||
logger = get_logger(__name__) | |||
@@ -145,6 +149,23 @@ class Node: | |||
assert isinstance(id, int) | |||
cls.__total_id = id | |||
def __copy__(self): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
result.__dict__.update(self.__dict__) | |||
return result | |||
def __deepcopy__(self, memo): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
state = {} | |||
memo[id(self)] = result | |||
for k, v in self.__dict__.items(): | |||
if not isinstance(v, weakref.ReferenceType) and k != "actual_node": | |||
state[k] = copy.deepcopy(v, memo) | |||
result.__dict__.update(state) | |||
return result | |||
class ModuleNode(Node): | |||
r"""``ModuleNode`` represents the Module objects.""" | |||
@@ -157,19 +178,28 @@ class ModuleNode(Node): | |||
super().__init__(expr, name, qualname) | |||
def __getstate__(self): | |||
return { | |||
state = { | |||
"expr": self.expr, | |||
"users": self.users, | |||
"_id": self._id, | |||
"_name": self._name, | |||
"_qualname": self._qualname, | |||
"module_type": self.module_type, | |||
"module_type": (self.module_type.__module__, self.module_type.__qualname__), | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
def __setstate__(self, state): | |||
if "_orig_name" in state: | |||
state["_qualname"] = state.pop("_orig_name") | |||
self.__dict__.update(state) | |||
try: | |||
if isinstance(self.module_type, tuple): | |||
mname, classname = self.module_type | |||
mtype = getattr(import_module(mname), classname) | |||
self.module_type = mtype | |||
except Exception: | |||
pass | |||
@property | |||
def owner(self): | |||
@@ -185,12 +215,26 @@ class TensorNode(Node): | |||
_shape = None # type: Tuple[int] | |||
_dtype = None # type: numpy.dtype | |||
_qparams = None | |||
_qparams = None # type: QParams | |||
_device = None | |||
_value = None # type: Tensor | |||
def __init__( | |||
self, | |||
expr: "Expr", | |||
name: str = None, | |||
qualname: str = None, | |||
shape: Tuple[int] = None, | |||
dtype: numpy.dtype = None, | |||
qparams: QParams = None, | |||
): | |||
super().__init__(expr, name, qualname) | |||
self._shape = shape | |||
self._dtype = dtype | |||
self._qparams = qparams | |||
def __getstate__(self): | |||
return { | |||
state = { | |||
"expr": self.expr, | |||
"users": self.users, | |||
"_id": self._id, | |||
@@ -201,6 +245,8 @@ class TensorNode(Node): | |||
"_name": self._name, | |||
"_qualname": self._qualname, | |||
} | |||
_check_obj_attr(state) | |||
return state | |||
def __setstate__(self, state): | |||
if "_orig_name" in state: | |||
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC): | |||
assert isinstance(node, TensorNode) | |||
assert isinstance(value, RawTensor) | |||
if isinstance(value, RawTensor): | |||
node._dtype = value.dtype | |||
try: | |||
node._dtype = value.dtype | |||
except RuntimeError: | |||
node._dtype = None | |||
node._shape = ( | |||
value._tuple_shape if isinstance(value, Tensor) else value.shape | |||
) | |||
@@ -7,15 +7,18 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from collections import OrderedDict | |||
from collections import OrderedDict, defaultdict | |||
from functools import partial | |||
from typing import Callable, NamedTuple | |||
import numpy as np | |||
from ..core._imperative_rt import OpDef | |||
from ..core._imperative_rt.common import CompNode | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._wrap import Device | |||
from ..core.tensor.dtype import QuantDtypeMeta | |||
from ..distributed import Group | |||
from ..module import Module | |||
from ..quantization.utils import LSQParams, QParams, QuantMode | |||
from ..tensor import Parameter, Tensor | |||
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = { | |||
type(Ellipsis), | |||
QuantMode, | |||
ArgsIndex, | |||
Group, | |||
} | |||
USER_REGISTERED_LEAF_TYPE = [] | |||
USER_REGISTERED_CONTAINER_TYPE = [] | |||
# if isinstance(obj, SUPPORTED_LEAF_CLS) or (obj is a class and issubclass(obj, SUPPORTED_LEAF_CLS)), the object is treated as a leaf node of the pytree | |||
SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number] | |||
SUPPORTED_LEAF_CLS = [ | |||
Module, | |||
Node, | |||
NodeMixin, | |||
np.dtype, | |||
np.ndarray, | |||
np.number, | |||
np.bool_, | |||
OpDef, | |||
] | |||
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||
def register_supported_type(type, flatten=None, unflatten=None): | |||
tp_info = (type.__module__, type.__qualname__) | |||
if flatten and unflatten: | |||
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
USER_REGISTERED_CONTAINER_TYPE.append(tp_info) | |||
else: | |||
SUPPORTED_LEAF_CLS.append(type) | |||
def _dict_flatten(inp): | |||
aux_data = [] | |||
results = [] | |||
for key, value in sorted(inp.items()): | |||
results.append(value) | |||
aux_data.append(key) | |||
return results, tuple(aux_data) | |||
USER_REGISTERED_LEAF_TYPE.append(tp_info) | |||
_register_supported_type(type, flatten, unflatten) | |||
def _dict_unflatten(inps, aux_data): | |||
return dict(zip(aux_data, inps)) | |||
def _register_supported_type(type, flatten=None, unflatten=None): | |||
if flatten and unflatten: | |||
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
else: | |||
SUPPORTED_LEAF_CLS.append(type) | |||
def _ordereddict_flatten(inp): | |||
def _dict_flatten(ordered, inp): | |||
aux_data = [] | |||
results = [] | |||
for key, value in inp.items(): | |||
dict_items = inp.items() if ordered else sorted(inp.items()) | |||
for key, value in dict_items: | |||
results.append(value) | |||
aux_data.append(key) | |||
return results, tuple(aux_data) | |||
def _ordereddict_unflatten(inps, aux_data): | |||
return OrderedDict(zip(aux_data, inps)) | |||
def _dict_unflatten(dict_type, inps, aux_data): | |||
return dict_type(zip(aux_data, inps)) | |||
def qparams_flatten(inp): | |||
@@ -99,33 +111,41 @@ def qparams_flatten(inp): | |||
return results, tuple(aux_data) | |||
def qparams_unflatten(inp, aux_data): | |||
obj = QParams.__new__(QParams) | |||
def qparams_unflatten(qparam_type, inp, aux_data): | |||
obj = qparam_type.__new__(qparam_type) | |||
for k, v in zip(aux_data, inp): | |||
setattr(obj, k, v) | |||
return obj | |||
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) | |||
register_supported_type(dict, _dict_flatten, _dict_unflatten) | |||
register_supported_type( | |||
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten | |||
_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) | |||
_register_supported_type( | |||
dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict) | |||
) | |||
_register_supported_type( | |||
defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict) | |||
) | |||
register_supported_type( | |||
_register_supported_type( | |||
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict) | |||
) | |||
_register_supported_type( | |||
slice, | |||
lambda x: ([x.start, x.stop, x.step], None), | |||
lambda x, aux_data: slice(x[0], x[1], x[2]), | |||
) | |||
register_supported_type(QParams, qparams_flatten, qparams_unflatten) | |||
_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams)) | |||
_register_supported_type( | |||
LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams) | |||
) | |||
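# Usage sketch (editor's illustration; `Point` is a hypothetical user type):
#
#     class Point:
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#
#     register_supported_type(
#         Point,
#         flatten=lambda p: ([p.x, p.y], None),
#         unflatten=lambda xs, aux: Point(*xs),
#     )
#
# Without flatten/unflatten, the type is registered as a leaf instead.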
def _is_leaf(obj): | |||
if isinstance(obj, type): | |||
return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE | |||
obj_type = obj if isinstance(obj, type) else type(obj) | |||
return ( | |||
isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE | |||
issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS)) | |||
or obj_type in SUPPORTED_LEAF_TYPE | |||
) | |||
@@ -5,30 +5,158 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Dict | |||
from importlib import import_module | |||
from typing import Dict, Tuple | |||
from ..core._imperative_rt import OpDef | |||
from ..core.ops import builtin | |||
from ..tensor import Tensor | |||
from ..version import __version__ | |||
from .utils import _convert_kwargs_to_args | |||
OPDEF_PARAM_LOADER = {} | |||
OPDEF_LOADER = {} | |||
FUNCTIONAL_LOADER = {} | |||
TENSORMETHOD_LOADER = {} | |||
MODULE_LOADER = {} | |||
def get_opdef_state(obj: OpDef) -> Dict: | |||
state = obj.__getstate__() | |||
state["type"] = type(obj) | |||
state["version"] = __version__ | |||
return state | |||
class _ModuleState: | |||
obj = None | |||
def __init__(self, module: Tuple, state: Dict, version: str): | |||
self.module = module | |||
self.state = state | |||
self.version = version | |||
def load_opdef_from_state(state: Dict) -> OpDef: | |||
assert "type" in state and issubclass(state["type"], OpDef) | |||
assert "version" in state | |||
opdef_type = state.pop("type") | |||
if opdef_type in OPDEF_PARAM_LOADER: | |||
loader = OPDEF_PARAM_LOADER[opdef_type] | |||
state = loader(state) | |||
state.pop("version") | |||
opdef_obj = opdef_type() | |||
opdef_obj.__setstate__(state) | |||
return opdef_obj | |||
@classmethod | |||
def get_module_state(cls, module): | |||
typem = (type(module).__module__, type(module).__qualname__) | |||
state = module.__dict__.copy() | |||
state.pop("_m_dump_modulestate", None) | |||
if hasattr(module, "_m_dump_modulestate"): | |||
assert isinstance(module._m_dump_modulestate, cls) | |||
module._m_dump_modulestate.__init__(typem, state, __version__) | |||
else: | |||
module.__dict__["_m_dump_modulestate"] = _ModuleState( | |||
typem, state, __version__ | |||
) | |||
return module._m_dump_modulestate | |||
def __getstate__(self): | |||
return {"module": self.module, "state": self.state, "version": self.version} | |||
def to_module(self): | |||
if self.obj is None: | |||
typem = getattr(import_module(self.module[0]), self.module[1]) | |||
m_obj = typem.__new__(typem) | |||
m_obj.__dict__.update(self.state) | |||
self.obj = m_obj | |||
return self.obj | |||
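# Round-trip sketch (editor's illustration): get_module_state() snapshots a
# builtin module's __dict__ for pickling, and to_module() rebuilds an
# equivalent module on load:
#
#     state = _ModuleState.get_module_state(conv)   # before pickle.dump
#     conv_again = state.to_module()                # after pickle.load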
def register_opdef_loader(*opdefs): | |||
def callback(loader): | |||
for opdef in opdefs: | |||
assert opdef not in OPDEF_LOADER | |||
OPDEF_LOADER[opdef] = loader | |||
return loader | |||
return callback | |||
def register_functional_loader(*funcs): | |||
def callback(loader): | |||
for func in funcs: | |||
assert func not in FUNCTIONAL_LOADER | |||
FUNCTIONAL_LOADER[func] = loader | |||
return loader | |||
return callback | |||
def register_module_loader(*module_types): | |||
def callback(loader): | |||
for module_type in module_types: | |||
assert module_type not in MODULE_LOADER | |||
MODULE_LOADER[module_type] = loader | |||
return loader | |||
return callback | |||
def register_tensor_method_loader(*methods): | |||
def callback(loader): | |||
for method in methods: | |||
assert method not in TENSORMETHOD_LOADER | |||
TENSORMETHOD_LOADER[method] = loader | |||
return loader | |||
return callback | |||
def _replace_args_kwargs(expr, new_args, new_kwargs): | |||
if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set( | |||
expr.kwargs.keys() | |||
): | |||
expr.set_args_kwargs(*new_args, **new_kwargs) | |||
def load_functional(expr): | |||
func = ( | |||
(expr.func.__module__, expr.func.__qualname__) | |||
if callable(expr.func) | |||
else expr.func | |||
) | |||
assert isinstance(func, tuple) | |||
if func in FUNCTIONAL_LOADER: | |||
loader = FUNCTIONAL_LOADER[func] | |||
loader(expr) | |||
mname, fname = func | |||
f = import_module(mname) | |||
for i in fname.split("."): | |||
f = getattr(f, i) | |||
expr.func = f | |||
assert callable(expr.func) | |||
if not hasattr(expr, "version") or expr.version != __version__: | |||
args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs) | |||
_replace_args_kwargs(expr, args, kwargs) | |||
def load_call_module_expr(expr): | |||
m_type = expr.inputs[0].module_type | |||
if isinstance(m_type, type): | |||
m_type = (m_type.__module__, m_type.__qualname__) | |||
if m_type in MODULE_LOADER: | |||
MODULE_LOADER[m_type](expr) | |||
if isinstance(expr.inputs[0].module_type, tuple): | |||
mname, classname = expr.inputs[0].module_type | |||
expr.inputs[0].module_type = getattr(import_module(mname), classname) | |||
if not hasattr(expr, "version") or expr.version != __version__: | |||
fwd_func = getattr(expr.inputs[0].module_type, "forward") | |||
args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs) | |||
_replace_args_kwargs(expr, args, kwargs) | |||
def load_call_tensor_method_expr(expr): | |||
if expr.method in TENSORMETHOD_LOADER: | |||
loader = TENSORMETHOD_LOADER[expr.method] | |||
loader(expr) | |||
if not hasattr(expr, "version") or expr.version != __version__: | |||
tmethod = ( | |||
getattr(expr.args[0], expr.method) | |||
if isinstance(expr.args[0], type) | |||
else getattr(Tensor, expr.method) | |||
) | |||
args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs) | |||
_replace_args_kwargs(expr, args, kwargs) | |||
def load_apply_expr(expr): | |||
opdef_type = type(expr.opdef) | |||
if opdef_type in OPDEF_LOADER: | |||
OPDEF_LOADER[opdef_type](expr) | |||
opdef_state = expr.opdef_state | |||
opdef_obj = opdef_state.pop("opdef_type")() | |||
opdef_obj.__setstate__(opdef_state) | |||
expr.opdef = opdef_obj |
@@ -14,6 +14,7 @@ import inspect | |||
import keyword | |||
import re | |||
import weakref | |||
from importlib import import_module | |||
from inspect import getcallargs, getmembers, isclass, ismethod | |||
from itertools import chain | |||
from types import FunctionType | |||
@@ -53,6 +54,7 @@ from ..quantization.observer import ( | |||
SyncMinMaxObserver, | |||
) | |||
from ..tensor import Tensor | |||
from ..version import __version__ | |||
from .expr import ( | |||
Apply, | |||
CallFunction, | |||
@@ -80,8 +82,27 @@ from .module_tracer import ( | |||
set_active_module_tracer, | |||
) | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
from .pytree import ArgsIndex, tree_flatten | |||
from .utils import replace_container_with_module_container | |||
from .pytree import ( | |||
USER_REGISTERED_CONTAINER_TYPE, | |||
USER_REGISTERED_LEAF_TYPE, | |||
ArgsIndex, | |||
TreeDef, | |||
_register_supported_type, | |||
tree_flatten, | |||
) | |||
from .serialization import ( | |||
_ModuleState, | |||
load_apply_expr, | |||
load_call_module_expr, | |||
load_call_tensor_method_expr, | |||
load_functional, | |||
) | |||
from .utils import ( | |||
_check_builtin_module_attr, | |||
_check_obj_attr, | |||
_convert_kwargs_to_args, | |||
replace_container_with_module_container, | |||
) | |||
logger = get_logger(__name__) | |||
@@ -341,7 +362,7 @@ class NameSpace: | |||
def create_unique_name(self, name: str, node: Any = None) -> str: | |||
assert isinstance(name, str), "The name must be a string" | |||
if name in self._used_names and self._used_names[name] is node: | |||
if name in self._used_names and (self._used_names[name] is node): | |||
return name | |||
name = re.sub("[^0-9a-zA-Z_]+", "_", name) | |||
@@ -1067,6 +1088,7 @@ class InternalGraph: | |||
if node2value[n][1] == 0: | |||
node2value.pop(n) | |||
if values is not None: | |||
assert len(values) == len(expr.outputs) | |||
for n, v in zip(expr.outputs, values): | |||
if ref_count(n) > 0: | |||
node2value[n] = [v, ref_count(n)] | |||
@@ -1105,13 +1127,27 @@ class InternalGraph: | |||
return res | |||
def __getstate__(self): | |||
state = self.__dict__.copy() | |||
if "_top_graph" in state: | |||
state.pop("_top_graph") | |||
state = { | |||
"_exprs": self._exprs, | |||
"_inputs": self._inputs, | |||
"_outputs": self._outputs, | |||
"_watch_point": [], | |||
"_end_point": [], | |||
"_namespace": self._namespace, | |||
"_rst": collections.defaultdict(list), | |||
"_name": self._name, | |||
"_qualname": self._qualname, | |||
} | |||
if self._total_ids: | |||
state["_total_ids"] = self._total_ids | |||
_check_obj_attr(state) | |||
return state | |||
def __setstate__(self, state): | |||
old_version = False | |||
if "_module_name" in state: | |||
old_version = True | |||
state["_qualname"] = state.pop("_module_name") | |||
@@ -1144,6 +1180,25 @@ class InternalGraph: | |||
self._namespace = NameSpace(self._name, self._qualname) | |||
self._re_associate_name() | |||
def __copy__(self): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
result.__dict__.update(self.__dict__) | |||
return result | |||
def __deepcopy__(self, memo): | |||
if id(self) in memo: | |||
return memo[id(self)] | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
state = {} | |||
memo[id(self)] = result | |||
for k, v in self.__dict__.items(): | |||
if not isinstance(v, weakref.ReferenceType): | |||
state[k] = copy.deepcopy(v, memo) | |||
result.__dict__.update(state) | |||
return result | |||
def _get_meth_name(obj, func): | |||
tp = obj if isinstance(obj, type) else type(obj) | |||
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func): | |||
def _wrapped_function(orig_func): | |||
@functools.wraps(orig_func) | |||
def wrapped_fn(*args, **kwargs): | |||
method_func = wrapped_fn | |||
if "method_func" in kwargs: | |||
method_func = kwargs.pop("method_func") | |||
method_func = kwargs.pop("method_func", wrapped_fn) | |||
if is_tracing_module(): | |||
unset_module_tracing() | |||
inputs, tree_def = tree_flatten((args, kwargs)) | |||
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func): | |||
if not NodeMixin.get(i, None): | |||
if isinstance(i, (RawTensor, NodeMixin)): | |||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||
meth_name, arg_type = None, None | |||
if args: | |||
meth_name = _get_meth_name(args[0], method_func) | |||
arg_type = args[0] if isinstance(args[0], type) else type(args[0]) | |||
args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs) | |||
meth_name = _get_meth_name(args[0], method_func) | |||
arg_type = args[0] if isinstance(args[0], type) else type(args[0]) | |||
if meth_name and arg_type and issubclass(arg_type, RawTensor): | |||
inputs, tree_def = tree_flatten((args, kwargs)) | |||
self = inputs[0] | |||
if meth_name == "__new__": | |||
if all([not isinstance(i, RawTensor) for i in inputs]): | |||
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func): | |||
call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
call_node.add_inputs(inputs[1:]) | |||
else: | |||
inputs, tree_def = tree_flatten((args, kwargs)) | |||
call_node = CallFunction.make(orig_func) | |||
call_node.add_inputs(inputs) | |||
@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin): | |||
"_record_wrapped_nodes", | |||
"_argdef_graph_map", | |||
"_argdef_outdef_map", | |||
"_check_qat_module", | |||
"nodes", | |||
"__class__", | |||
"__dict__", | |||
"_is_top", | |||
] | |||
def __init__(self, mod, is_top_module=False): | |||
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin): | |||
qat_module.weight_fake_quant.set_qparams(qparams) | |||
def build(self): | |||
if self._is_builtin or isinstance(self._mod, TracedModule): | |||
if module_tracer.is_builtin(self._mod) or isinstance( | |||
self._mod, TracedModule | |||
): | |||
mod_type = type(self._mod) | |||
else: | |||
assert isinstance(self._mod, (Observer, _FakeQuantize)) | |||
mod_type = ( | |||
Observer if isinstance(self._mod, Observer) else _FakeQuantize | |||
) | |||
if self._is_builtin: | |||
assert module_tracer.is_builtin(self._mod) | |||
mod_type = type(self._mod) | |||
for node in self.nodes: | |||
node.module_type = mod_type | |||
return self._mod | |||
else: | |||
is_qat = isinstance(self._mod, QATModule) | |||
is_qat = isinstance(self._mod, QATModule) or ( | |||
isinstance(self._mod, TracedModule) and self._mod.is_qat | |||
) | |||
traced_module = TracedModule( | |||
self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat | |||
) | |||
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin): | |||
traced_module.with_act = self._mod.with_act | |||
traced_module.with_weight = self._mod.with_weight | |||
if not hasattr(traced_module, "act_fake_quant"): | |||
traced_module.act_fakequant = None | |||
traced_module.act_fake_quant = None | |||
if not hasattr(traced_module, "act_observer"): | |||
traced_module.act_observer = None | |||
if not hasattr(traced_module, "weight_fake_quant"): | |||
traced_module.weight_fakequant = None | |||
traced_module.weight_fake_quant = None | |||
if not hasattr(traced_module, "weight_observer"): | |||
traced_module.weight_observer = None | |||
set_module_tracing() | |||
if self._is_top: | |||
traced_module._update_ref() | |||
return traced_module | |||
def _record_wrapped_nodes(self, node): | |||
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin): | |||
# prepare args and kwargs for inner graph | |||
if "method_func" in kwargs: | |||
kwargs.pop("method_func") | |||
args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True) | |||
def mark_constant(x): | |||
node = NodeMixin.get(x, None) | |||
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin): | |||
callnode.arg_def = tree_def | |||
if ( | |||
self._is_builtin | |||
or tree_def in self._argdef_graph_map | |||
or isinstance(self._mod, TracedModule) | |||
): | |||
if self._is_builtin or tree_def in self._argdef_graph_map: | |||
unset_module_tracing() | |||
rst = self._mod(*args, **kwargs) | |||
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | |||
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin): | |||
self._body = None | |||
elif tree_def in self._argdef_graph_map: | |||
self._body = self._argdef_graph_map[tree_def] | |||
else: | |||
self._mod._is_top = False | |||
self._body = self._mod.argdef_graph_map[tree_def] | |||
module_qualname = NodeMixin.get(self).qualname | |||
if module_qualname != self._body.qualname: | |||
src_name, dst_name = self._body.qualname, module_qualname | |||
def replace_qualname(g): | |||
attr_name = get_suffix_name(src_name, g.qualname) | |||
if attr_name is not None: | |||
g._qualname = ( | |||
("%s.%s" % (dst_name, attr_name)) | |||
if attr_name | |||
else dst_name | |||
) | |||
assert get_suffix_name(dst_name, g.qualname) is not None | |||
for mod in self._mod.modules(): | |||
if not hasattr(mod, "argdef_graph_map"): | |||
continue | |||
for g in mod.argdef_graph_map.values(): | |||
replace_qualname(g) | |||
g._namespace.qualname = g.qualname | |||
for n in g.nodes(False): | |||
replace_qualname(n) | |||
else: | |||
self_node = None | |||
orig_self = NodeMixin.get(self) | |||
parent_graph = active_module_tracer().current_scope() | |||
module_qualname = orig_self._qualname | |||
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin): | |||
active_module_tracer().push_scope(self._body) | |||
# rebind self to new input node | |||
if self_node: | |||
NodeMixin.wrap_safe(self, self_node) | |||
active_module_tracer().current_scope()._add_input(self_node) | |||
else: | |||
NodeMixin.wrap_safe( | |||
self, | |||
self_node | |||
if self_node | |||
else Input.make( | |||
name="self", | |||
qualname=module_qualname, | |||
type=NodeMixin.get_wrapped_type(self), | |||
), | |||
) | |||
NodeMixin.wrap_safe( | |||
self, | |||
Input.make( | |||
name="self", | |||
qualname=module_qualname, | |||
type=NodeMixin.get_wrapped_type(self), | |||
), | |||
) | |||
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] | |||
# prepare args and kwargs for inner graph | |||
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin): | |||
return x | |||
args = [self] | |||
for i, v in enumerate(inputs[1:]): | |||
args.append(wrap(v, idx2key[i + 1])) | |||
orig_traced_inputs = ( | |||
None | |||
if not isinstance(self._mod, TracedModule) | |||
else self._mod.argdef_graph_map[tree_def].inputs | |||
) | |||
ind = 1 | |||
for v in inputs[1:]: | |||
if isinstance(v, (RawTensor, NodeMixin)): | |||
args_name = ( | |||
orig_traced_inputs[ind]._name | |||
if orig_traced_inputs | |||
else idx2key[ind] | |||
) | |||
ind += 1 | |||
args.append(wrap(v, args_name)) | |||
else: | |||
args.append(v) | |||
args, kwargs = tree_def.unflatten(args) | |||
active_module_tracer().patcher.auto_patch( | |||
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin): | |||
attr = getattr(type(self._mod), name).__get__(self, type(self)) | |||
else: | |||
attr = getattr(self._mod, name) | |||
if ( | |||
isinstance(attr, FunctionType) | |||
and id(attr) in active_module_tracer().patcher.patched_fn_ids | |||
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin): | |||
wrapped = self.__getattr__(name) | |||
if isinstance(wrapped, TracedModuleBuilder): | |||
if not isinstance(mod_attr, (List, Dict)): | |||
if not isinstance(mod_attr, (List, Dict, QATModule)): | |||
assert mod_attr is wrapped._mod | |||
else: | |||
assert mod_attr is wrapped | |||
@@ -1977,8 +2011,6 @@ class TracedModule(Module): | |||
def graph(self) -> InternalGraph: | |||
"""Return the ``InternalGraph`` of this ``TracedModule``. | |||
""" | |||
if self._is_top: | |||
self._update_ref() | |||
assert len(self.argdef_graph_map) == 1 | |||
return list(self.argdef_graph_map.values())[0] | |||
@@ -2112,7 +2144,7 @@ class TracedModule(Module): | |||
if hasattr(obj, "argdef_graph_map") | |||
else None | |||
) | |||
if expr_graph is not None: | |||
if expr_graph is not None and not obj.is_qat: | |||
exprs = _flatten_subgraph(graph, expr_graph, expr, obj) | |||
if parent_graph is not None: | |||
@@ -2137,26 +2169,119 @@ class TracedModule(Module): | |||
) | |||
new_module.graph._re_associate_name() | |||
new_module.graph.compile() | |||
new_module._update_ref() | |||
new_module.graph._reset_ids() | |||
return new_module | |||
def __getstate__(self): | |||
d = self.__dict__ | |||
d = self.__dict__.copy() | |||
for k in Module.__dict__: | |||
d.pop(k, None) | |||
_check_obj_attr(d) | |||
for k in d: | |||
if module_tracer.is_builtin(d[k]): | |||
assert _check_builtin_module_attr( | |||
d[k] | |||
), "Module {} can not be serialized. ".format(type(d[k])) | |||
d[k] = _ModuleState.get_module_state(d[k]) | |||
dump_info = { | |||
"version": __version__, | |||
"register_type": USER_REGISTERED_LEAF_TYPE, | |||
"register_container_type": USER_REGISTERED_CONTAINER_TYPE, | |||
"register_mdule": USER_REGISTERED_MODULE, | |||
"register_function": USER_REGISTERED_FUNCTION, | |||
} | |||
d["dump_info"] = dump_info | |||
return d | |||
def __setstate__(self, state): | |||
for k, v in state.items(): | |||
if isinstance(v, _ModuleState): | |||
state[k] = v.to_module() | |||
self.__dict__.update(state) | |||
self._update_ref() | |||
for _, graph in self.argdef_graph_map.items(): | |||
for expr in graph._exprs: | |||
if isinstance(expr, CallFunction): | |||
load_functional(expr) | |||
if isinstance(expr, CallMethod): | |||
if expr.method == "__call__": | |||
load_call_module_expr(expr) | |||
else: | |||
load_call_tensor_method_expr(expr) | |||
if isinstance(expr, Apply): | |||
load_apply_expr(expr) | |||
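# The passes below repair graphs restored from older dumps: re-insert
# missing producer exprs, reconcile node.users with expr.inputs, rebind
# output nodes to their exprs, and re-run output naming.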
for _, graph in self.argdef_graph_map.items(): | |||
ind = 0 | |||
while ind < len(graph._exprs): | |||
cur_expr = graph._exprs[ind] | |||
has_new_expr = False | |||
for i in cur_expr.inputs: | |||
if i.expr not in graph._exprs and not isinstance(i.expr, Input): | |||
graph._exprs.insert(ind, i.expr) | |||
has_new_expr = True | |||
if not has_new_expr: | |||
ind += 1 | |||
for expr in graph._exprs: | |||
for i in expr.inputs: | |||
if expr.inputs.count(i) != i.users.count(expr): | |||
add_or_del_count = expr.inputs.count(i) - i.users.count(expr) | |||
if add_or_del_count > 0: | |||
i.users.extend([expr] * add_or_del_count) | |||
else: | |||
[i.users.remove(expr) for _ in range(-add_or_del_count)] | |||
for o in expr.outputs: | |||
if o.expr is not expr: | |||
assert o not in o.expr.outputs | |||
o.expr = expr | |||
for node in graph.nodes(False): | |||
# drop exprs from node.users that do not actually take node as an input | |||
node.users = [e for e in node.users if node in e.inputs] | |||
for expr in graph._exprs: | |||
graph._namespace.auto_naming_for_outputs(expr) | |||
self._update_ref() | |||
for _, graph in self.argdef_graph_map.items(): | |||
graph._reset_ids() | |||
def __copy__(self): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
result.__dict__.update(self.__dict__) | |||
return result | |||
def __deepcopy__(self, memo): | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
state = {} | |||
memo[id(self)] = result | |||
for k, v in self.__dict__.items(): | |||
if not isinstance(v, weakref.ReferenceType): | |||
state[k] = copy.deepcopy(v, memo) | |||
result.__dict__.update(state) | |||
result._update_ref() | |||
return result | |||
def cpp_apply_module_trace(opdef, *args): | |||
return Apply.apply_module_trace_hook(opdef, *args) | |||
USER_REGISTERED_MODULE = [] | |||
USER_REGISTERED_FUNCTION = [] | |||
def register_as_builtin(mod_cls: Type[Module]) -> None: | |||
r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. | |||
Args: | |||
mod_cls: the module class which will be treated as builtin module in tracing. | |||
""" | |||
USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__)) | |||
module_tracer.register_as_builtin(mod_cls) | |||
@@ -2181,6 +2306,7 @@ def wrap(func: Callable): | |||
Args: | |||
func: the function of the global function to insert into the graph when it's called. | |||
""" | |||
USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__)) | |||
assert callable(func), "func must be a callable" | |||
assert hasattr(func, "__code__") | |||
fn_name = func.__code__.co_name | |||
@@ -2247,6 +2373,8 @@ def trace_module( | |||
NodeMixin.wrap_safe( | |||
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | |||
) | |||
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||
inputs, _ = tree_flatten((args, kwargs)) | |||
for _, i in enumerate(inputs): | |||
# assert isinstance(i, Tensor), "not support " | |||
@@ -5,12 +5,17 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import copy | |||
import inspect | |||
from collections.abc import MutableMapping, MutableSequence | |||
from typing import Dict, Iterable, List, Optional, Sequence | |||
from typing import Dict, Iterable, List, Optional, Sequence, Type | |||
from .. import get_logger | |||
from ..module import Module | |||
logger = get_logger(__name__) | |||
def replace_container_with_module_container(container): | |||
has_module = False | |||
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container): | |||
return has_module, module_container | |||
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||
# is_bounded = True when func is a method and provided args don't include 'self' | |||
arg_specs = inspect.getfullargspec(func) | |||
arg_specs_args = arg_specs.args | |||
if is_bounded: | |||
arg_specs_args = arg_specs.args[1:] | |||
new_args = [] | |||
new_kwargs = {} | |||
new_args.extend(args) | |||
if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()): | |||
repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) | |||
raise TypeError( | |||
"{} got multiple values for argument {}".format( | |||
func.__qualname__, ", ".join(repeated_arg_name) | |||
) | |||
) | |||
if len(new_args) < len(arg_specs.args): | |||
for ind in range(len(new_args), len(arg_specs_args)): | |||
arg_name = arg_specs_args[ind] | |||
if arg_name in kwargs: | |||
new_args.append(kwargs[arg_name]) | |||
else: | |||
index = ind - len(arg_specs_args) + len(arg_specs.defaults) | |||
assert index < len(arg_specs.defaults) and index >= 0 | |||
new_args.append(arg_specs.defaults[index]) | |||
for kwarg_name in arg_specs.kwonlyargs: | |||
if kwarg_name in kwargs: | |||
new_kwargs[kwarg_name] = kwargs[kwarg_name] | |||
else: | |||
assert kwarg_name in arg_specs.kwonlydefaults | |||
new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name] | |||
for k, v in kwargs.items(): | |||
if k not in arg_specs.args and k not in arg_specs.kwonlyargs: | |||
if arg_specs.varkw is None: | |||
raise TypeError( | |||
"{} got an unexpected keyword argument {}".format( | |||
func.__qualname__, k | |||
) | |||
) | |||
new_kwargs[k] = v | |||
return tuple(new_args), new_kwargs | |||
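# Worked example (editor's sketch):
#
#     def f(a, b=1, *, c=2):
#         pass
#
#     _convert_kwargs_to_args(f, (0,), {"c": 3})  # -> ((0, 1), {"c": 3})
#
# i.e. defaults are materialized into positional args and keyword-only
# arguments are kept in kwargs.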
def _check_obj_attr(obj): | |||
# check whether all attributes of obj are serializable | |||
from .pytree import tree_flatten | |||
from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef | |||
from .expr import Expr | |||
from .traced_module import TracedModule, InternalGraph, NameSpace | |||
def _check_leaf_type(leaf): | |||
leaf_type = leaf if isinstance(leaf, type) else type(leaf) | |||
traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace] | |||
return ( | |||
issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types)) | |||
or leaf_type in SUPPORTED_LEAF_TYPE | |||
) | |||
for _, v in obj.items(): | |||
leafs, _ = tree_flatten(v, is_leaf=lambda _: True) | |||
for leaf in leafs: | |||
assert _check_leaf_type( | |||
leaf | |||
), "Type {} is not supported by traced module".format( | |||
leaf if isinstance(leaf, type) else type(leaf) | |||
) | |||
def _check_builtin_module_attr(mod): | |||
from .pytree import _is_leaf as _check_leaf_type | |||
from .pytree import tree_flatten | |||
# check whether all attributes of a builtin module are serializable | |||
is_non_serializable_module = lambda m: isinstance( | |||
m, Module | |||
) and not _check_builtin_module_attr(m) | |||
for k, v in mod.__dict__.items(): | |||
if k == "_m_dump_modulestate": | |||
continue | |||
if is_non_serializable_module(v): | |||
return False | |||
elif not isinstance(v, Module): | |||
leafs, _ = tree_flatten(v, is_leaf=lambda _: True) | |||
for leaf in leafs: | |||
if not _check_leaf_type(leaf) or is_non_serializable_module(leaf): | |||
logger.warn( | |||
"Type {} is not supported by traced module".format( | |||
leaf if isinstance(leaf, type) else type(leaf) | |||
) | |||
) | |||
return False | |||
return True | |||
class _ModuleList(Module, MutableSequence): | |||
r"""A List-like container. | |||
@@ -15,7 +15,6 @@ import numpy as np | |||
import megengine as mge | |||
from megengine import Parameter, Tensor | |||
from megengine.core.ops import builtin | |||
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state | |||
def test_tensor_serialization(): | |||
@@ -88,25 +87,3 @@ def test_compatibility(): | |||
test_old_tensor("tensor_v1_1.mge") | |||
test_old_tensor("tensor_v1_2.mge") | |||
def test_opdef_serialization(): | |||
with TemporaryFile() as f: | |||
x = builtin.Elemwise(mode="Add") | |||
pickle.dump(get_opdef_state(x), f) | |||
f.seek(0) | |||
load_x = load_opdef_from_state(pickle.load(f)) | |||
assert x == load_x | |||
with TemporaryFile() as f: | |||
x = builtin.Convolution(stride_h=9, compute_mode="float32") | |||
x.strategy = ( | |||
builtin.Convolution.Strategy.PROFILE | |||
| builtin.Convolution.Strategy.HEURISTIC | |||
| builtin.Convolution.Strategy.REPRODUCIBLE | |||
) | |||
pickle.dump(get_opdef_state(x), f) | |||
f.seek(0) | |||
load_x = load_opdef_from_state(pickle.load(f)) | |||
assert x.strategy == load_x.strategy | |||
assert x == load_x |
@@ -85,12 +85,12 @@ class NewModule(M.Module): | |||
return x | |||
def _check_expr_users(traced_module): | |||
def _check_expr_users(flattened_module): | |||
node_user = defaultdict(list) | |||
for expr in traced_module.graph._exprs: | |||
for expr in flattened_module.graph._exprs: | |||
for node in expr.inputs: | |||
node_user[node].append(expr) | |||
for node in traced_module.graph.nodes(): | |||
for node in flattened_module.graph.nodes(): | |||
node.users.sort(key=lambda m: m._id) | |||
node_user[node].sort(key=lambda m: m._id) | |||
assert node.users == node_user[node] | |||
@@ -8,6 +8,7 @@ import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.module.qat as QM | |||
import megengine.quantization as Q | |||
from megengine import Tensor | |||
from megengine.module.qat.module import QATModule | |||
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str): | |||
return getattr(self, name) | |||
class MyConvBnRelu2d(M.ConvBnRelu2d): | |||
pass | |||
class MyQATConvBnRelu2d(QM.ConvBnRelu2d): | |||
pass | |||
class Myblcok(M.Module): | |||
def __init__(self,): | |||
super().__init__() | |||
self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) | |||
self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1) | |||
self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
self.add = M.Elemwise("FUSE_ADD_RELU") | |||
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):

def build_observed_net(net: M.Module, observer_cls):
    qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls))
    qat_net = Q.quantize_qat(
        net,
        qconfig=get_observer_config(observer_cls),
        mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
    )
    Q.enable_observer(qat_net)
    inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
    qat_net(inp)
@@ -134,6 +147,15 @@ def test_trace_qat():
                check_qparams(weight_qparams, traced_weight_qparams)
            if act_qparams:
                check_qparams(act_qparams, traced_act_qparams)
        flatten_traced_net = traced_net.flatten()
        conv0_node = flatten_traced_net.graph.get_node_by_name(
            "MyModule_block0_conv0"
        ).as_unique()
        conv0_out_node = flatten_traced_net.graph.get_node_by_name(
            "MyModule_block0_conv0_out"
        ).as_unique()
        assert isinstance(conv0_node.owner, TracedModule)
        assert conv0_out_node.expr.inputs[0] is conv0_node

    _check_qat_module(build_observed_net(MyModule(), Q.MinMaxObserver))
    _check_qat_module(build_observed_net(MyModule(), MyMinMaxObserver))
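
# Illustrative sketch of the ``mapping`` argument exercised above: it lets
# quantize_qat convert a user-defined float module into a user-defined QAT
# counterpart.  This reuses the classes and the get_observer_config helper
# from this file; wrapped in a function so nothing runs at import time.
def _mapping_usage_sketch():
    net = MyConvBnRelu2d(3, 3, 3, 1, 1)
    # map the custom float module to its custom QAT module during conversion
    return Q.quantize_qat(
        net,
        qconfig=get_observer_config(Q.MinMaxObserver),
        mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
    )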
@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from collections import defaultdict
from tempfile import TemporaryFile

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.traced_module.serialization as S
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise
from megengine.module import Module
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallMethod, Constant
from megengine.traced_module.node import TensorNode
from megengine.traced_module.serialization import (
    register_functional_loader,
    register_module_loader,
    register_opdef_loader,
    register_tensor_method_loader,
)
from megengine.traced_module.utils import _convert_kwargs_to_args
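
# Shared sanity checks for the tests below:
#   _check_id         -- node/expr ids in a graph are unique, and
#                        graph._total_ids points one past the largest id
#   _check_name       -- node names in a flattened graph are unique
#   _check_expr_users -- each node's recorded ``users`` list matches the
#                        consumers recomputed by walking the graph's exprs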
def _check_id(traced_module):
    _total_ids = traced_module.graph._total_ids
    node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
    assert len(set(node_ids)) == len(node_ids)
    assert max(node_ids) + 1 == _total_ids[0]

    expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
    assert len(set(expr_ids)) == len(expr_ids)
    assert max(expr_ids) + 1 == _total_ids[1]

def _check_name(flattened_module):
    node_names = [n._name for n in flattened_module.graph.nodes().as_list()]
    assert len(set(node_names)) == len(node_names)

def _check_expr_users(traced_module):
    node_user = defaultdict(list)
    for expr in traced_module.graph._exprs:
        for node in expr.inputs:
            node_user[node].append(expr)
        if isinstance(expr, CallMethod) and expr.graph:
            _check_expr_users(expr.inputs[0].owner)
    for node in traced_module.graph.nodes(False):
        node.users.sort(key=lambda m: m._id)
        node_user[node].sort(key=lambda m: m._id)
        assert node.users == node_user[node]

class MyBlock(Module):
@@ -48,5 +93,274 @@ def test_dump_and_load():
    traced_module = trace_module(module, x)
    np.testing.assert_array_equal(expect, traced_module(x))
    obj = pickle.dumps(traced_module)
    pickle.loads(obj)
    new_tm = pickle.loads(obj)
    _check_id(new_tm)
    _check_expr_users(new_tm)
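    # Graph ids are expected to be re-assigned densely on load, so reset the
    # original graph's ids before comparing the two graphs one-to-one below.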
    traced_module.graph._reset_ids()
    old_nodes = traced_module.graph.nodes().as_list()
    new_nodes = new_tm.graph.nodes().as_list()
    old_exprs = traced_module.graph.exprs().as_list()
    new_exprs = new_tm.graph.exprs().as_list()
    assert len(old_nodes) == len(new_nodes)
    for i, j in zip(old_nodes, new_nodes):
        assert i._name == j._name
        assert i._qualname == j._qualname
        assert i._id == j._id
    assert len(old_exprs) == len(new_exprs)
    for i, j in zip(old_exprs, new_exprs):
        assert i._id == j._id
    np.testing.assert_array_equal(expect, traced_module(x))

def test_opdef_loader():
    class MyModule1(Module):
        def forward(self, x, y):
            op = Elemwise("ADD")
            return apply(op, x, y)[0]

    m = MyModule1()
    x = Tensor(np.ones((20,)))
    y = Tensor(np.ones((20,)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.OPDEF_LOADER
    S.OPDEF_LOADER = {}

    @register_opdef_loader(Elemwise)
    def add_opdef_loader(expr):
        if expr.opdef_state["mode"] == "ADD":
            expr.opdef_state["mode"] = "MUL"
            node = expr.inputs[1]
            astype_expr = CallMethod(node, "astype")
            oup = TensorNode(
                astype_expr,
                shape=node.shape,
                dtype=expr.inputs[0].dtype,
                qparams=node.qparams,
            )
            astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
            astype_expr.return_val = (oup,)
            expr.inputs[1] = oup
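
    # After reload, the loader has rewritten the stored Elemwise mode from
    # "ADD" to "MUL" and inserted one astype CallMethod on the second input,
    # so the graph holds exactly two exprs; with all-ones inputs, x * y == x.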
    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_id(new_module)
    _check_expr_users(new_module)
    _check_name(new_module.flatten())
    assert (
        isinstance(new_module.graph._exprs[0], CallMethod)
        and new_module.graph._exprs[1].opdef.mode == "MUL"
        and len(new_module.graph._exprs) == 2
    )
    result = new_module(x, y)
    np.testing.assert_equal(result.numpy(), x.numpy())
    S.OPDEF_LOADER = orig_loader_dict

def test_functional_loader():
    class MyModule2(Module):
        def forward(self, x, y):
            return F.conv2d(x, y)

    m = MyModule2()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    y = Tensor(np.random.random((3, 3, 3, 3)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.FUNCTIONAL_LOADER
    S.FUNCTIONAL_LOADER = {}

    @register_functional_loader(("megengine.functional.nn", "conv2d"))
    def conv2df_loader(expr):
        # expr.func = ("megengine.functional.nn", "conv2d")
        kwargs = expr.kwargs
        orig_weight = expr.named_args["weight"]
        astype_expr = CallMethod(orig_weight, "astype")
        oup = TensorNode(
            astype_expr,
            shape=orig_weight.shape,
            dtype=orig_weight.dtype,
            qparams=orig_weight.qparams,
        )
        astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
        astype_expr.return_val = (oup,)
        expr.set_arg("weight", oup)
    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_expr_users(new_module)
    _check_id(new_module)
    result = new_module(x, y)
    gt = m(x, y)
    assert (
        isinstance(new_module.graph._exprs[0], CallMethod)
        and len(new_module.graph._exprs) == 2
    )
    np.testing.assert_equal(result.numpy(), gt.numpy())
    S.FUNCTIONAL_LOADER = orig_loader_dict

def test_tensor_method_loader():
    class MyModule3(Module):
        def forward(self, x):
            return x + 1

    m = MyModule3()
    x = Tensor(np.ones((20,)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.TENSORMETHOD_LOADER
    S.TENSORMETHOD_LOADER = {}

    @register_tensor_method_loader("__add__")
    def add_loader(expr):
        args = list(expr.args)
        if not isinstance(args[1], TensorNode):
            args[1] = Tensor(args[1])
            node = Constant(args[1], "const").outputs[0]
            astype_expr = CallMethod(node, "astype")
            oup = TensorNode(
                astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
            )
            astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
            astype_expr.return_val = (oup,)
            add_expr = CallMethod(oup, "__add__")
            add_expr.set_args_kwargs(oup, oup)
            oup1 = TensorNode(
                add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
            )
            add_expr.return_val = (oup1,)
            args[1] = oup1
            expr.set_args_kwargs(*args)
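
    # The loader replaces the scalar ``1`` with a Constant node, casts it to
    # the input dtype, then adds it to itself (1 + 1), so the loaded graph
    # holds four exprs and computes x + 2 instead of x + 1.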
    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_expr_users(new_module)
    _check_id(new_module)
    result = new_module(x)
    gt = m(x)
    assert (
        isinstance(new_module.graph._exprs[0], Constant)
        and len(new_module.graph._exprs) == 4
    )
    np.testing.assert_equal(result.numpy(), (x + 2).numpy())
    S.TENSORMETHOD_LOADER = orig_loader_dict

def test_module_loader():
    class MyModule4(Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    m = MyModule4()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.MODULE_LOADER
    S.MODULE_LOADER = {}

    @register_module_loader(("megengine.module.conv", "Conv2d"))
    def conv2dm_loader(expr):
        module = expr.inputs[0].owner
        args = list(expr.args)
        orig_inp = args[1]
        astype_expr = CallMethod(orig_inp, "astype")
        oup = TensorNode(
            astype_expr,
            shape=orig_inp.shape,
            dtype=orig_inp.dtype,
            qparams=orig_inp.qparams,
        )
        astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
        astype_expr.return_val = (oup,)
        args[1] = oup
        expr.set_args_kwargs(*args)
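
    # The loader casts the module input to the weight's dtype before the
    # Conv2d call; the inserted astype accounts for the extra expr (three in
    # total, with a CallMethod at index 1), and the identity cast leaves the
    # output unchanged.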
    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    result = new_module(x)
    gt = m(x)
    assert (
        isinstance(new_module.graph._exprs[1], CallMethod)
        and len(new_module.graph._exprs) == 3
    )
    np.testing.assert_equal(result.numpy(), gt.numpy())
    S.MODULE_LOADER = orig_loader_dict

def test_shared_module():
    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            self.a = M.Elemwise("ADD")
            self.b = self.a

        def forward(self, x, y):
            z = self.a(x, y)
            z = self.b(z, y)
            return z

    x = Tensor(1)
    y = Tensor(2)
    m = MyModule()
    tm = trace_module(m, x, y)
    obj = pickle.dumps(tm)
    load_tm = pickle.loads(obj)
    _check_expr_users(load_tm)
    _check_name(load_tm.flatten())
    _check_id(load_tm)
    assert load_tm.a is load_tm.b

def test_convert_kwargs_to_args():
    def func(a, b, c=4, *, d, e=3, f=4):
        pass

    args = (1,)
    kwargs = {"b": 1, "d": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
    assert new_args == (1, 1, 4)
    assert new_kwargs == {"d": 6, "e": 3, "f": 4}

    args = (1,)
    kwargs = {"d": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
    assert new_args == (1, 4)
    assert new_kwargs == {"d": 6, "e": 3, "f": 4}

    def func1(a, b, c, d, e, *, f):
        pass

    args = ()
    kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
    assert new_args == (1, 2, 3, 4, 5)
    assert new_kwargs == {"f": 6}

def test_opdef_serialization():
    with TemporaryFile() as f:
        x = builtin.Elemwise(mode="Add")
        pickle.dump(x, f)
        f.seek(0)
        load_x = pickle.load(f)
        assert x == load_x

    with TemporaryFile() as f:
        x = builtin.Convolution(stride_h=9, compute_mode="float32")
        x.strategy = (
            builtin.Convolution.Strategy.PROFILE
            | builtin.Convolution.Strategy.HEURISTIC
            | builtin.Convolution.Strategy.REPRODUCIBLE
        )
        pickle.dump(x, f)
        f.seek(0)
        load_x = pickle.load(f)
        assert x.strategy == load_x.strategy
        assert x == load_x
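
# The same round-trip also works in memory, since builtin opdefs now pickle
# directly (no get_opdef_state/load_opdef_from_state wrapper), e.g.:
#     op = builtin.Elemwise(mode="Add")
#     assert pickle.loads(pickle.dumps(op)) == op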