
feat(traced_module): support traced module backward compatible serialization

GitOrigin-RevId: aaa9e51c74
release-1.7
Megvii Engine Team · 3 years ago
parent commit 7b19bc76fb
13 changed files with 1314 additions and 191 deletions
1. imperative/python/megengine/traced_module/__init__.py (+1, -0)
2. imperative/python/megengine/traced_module/compat.py (+136, -0)
3. imperative/python/megengine/traced_module/expr.py (+276, -27)
4. imperative/python/megengine/traced_module/module_tracer.py (+0, -1)
5. imperative/python/megengine/traced_module/node.py (+54, -5)
6. imperative/python/megengine/traced_module/pytree.py (+51, -31)
7. imperative/python/megengine/traced_module/serialization.py (+146, -18)
8. imperative/python/megengine/traced_module/traced_module.py (+207, -79)
9. imperative/python/megengine/traced_module/utils.py (+101, -1)
10. imperative/python/test/unit/core/test_serialization.py (+0, -23)
11. imperative/python/test/unit/traced_module/test_modification.py (+3, -3)
12. imperative/python/test/unit/traced_module/test_qat_module.py (+24, -2)
13. imperative/python/test/unit/traced_module/test_serialization.py (+315, -1)

imperative/python/megengine/traced_module/__init__.py (+1, -0)

@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
+from . import compat
 from .traced_module import (
     TracedModule,
     _register_all_builtin_module,


imperative/python/megengine/traced_module/compat.py (+136, -0)

@@ -0,0 +1,136 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import numpy as np

from .. import tensor
from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
    register_functional_loader,
    register_module_loader,
    register_opdef_loader,
    register_tensor_method_loader,
)


"""
# Examples of Expr loaders

from ..core.ops.builtin import Elemwise

@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
    if expr.opdef_state["mode"] == "ADD":
        expr.opdef_state["mode"] = "MUL"
    node = expr.inputs[1]
    astype_expr = CallMethod(node, "astype")
    oup = TensorNode(
        astype_expr,
        shape=node.shape,
        dtype=expr.inputs[0].dtype,
        qparams=node.qparams,
    )

    astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
    astype_expr.return_val = (oup,)
    expr.inputs[1] = oup


@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
    # expr.func = ("megengine.functional.nn", "conv2d")
    kwargs = expr.kwargs
    orig_weight = expr.named_args["weight"]

    astype_expr = CallMethod(orig_weight, "astype")
    oup = TensorNode(
        astype_expr,
        shape=orig_weight.shape,
        dtype=orig_weight.dtype,
        qparams=orig_weight.qparams,
    )

    astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
    astype_expr.return_val = (oup,)

    expr.set_arg("weight", oup)


@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
    module = expr.inputs[0].owner
    args = list(expr.args)
    orig_inp = args[1]
    astype_expr = CallMethod(orig_inp, "astype")
    oup = TensorNode(
        astype_expr,
        shape=orig_inp.shape,
        dtype=orig_inp.dtype,
        qparams=orig_inp.qparams,
    )
    astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
    astype_expr.return_val = (oup,)
    args[1] = oup
    expr.set_args_kwargs(*args)


@register_tensor_method_loader("__add__")
def add_loader(expr):
    args = list(expr.args)
    if not isinstance(args[1], TensorNode):
        args[1] = tensor(args[1])
        node = Constant(args[1], "const").outputs[0]

        astype_expr = CallMethod(node, "astype")
        oup = TensorNode(
            astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
        )

        astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
        astype_expr.return_val = (oup,)
        args[1] = oup
        expr.set_args_kwargs(*args)
"""


@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    # mge 1.6
    if not hasattr(expr, "version"):
        module = expr.inputs[0].owner
        if not hasattr(module, "param_dim"):
            module.param_dim = "dim_1c11"


@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    # mge 1.6
    if not hasattr(expr, "version"):
        module = expr.inputs[0].owner
        if not hasattr(module.bn, "param_dim"):
            module.bn.param_dim = "dim_1c11"


@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    # mge 1.6
    if not hasattr(expr, "version"):
        output = expr.outputs[-1]
        oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
        expr.outputs.insert(4, oup)
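A usage sketch, not part of this commit: downstream code can register loaders the same way to keep its own old dumps loadable. The module path ("myproj.modules", "MyNorm") and the "eps" attribute below are hypothetical; the pattern mirrors bn2d_module_loader above.

from megengine.traced_module.serialization import register_module_loader

# Hypothetical: MyNorm gained an "eps" attribute in a later release; fill in
# the old default when loading a dump produced before the attribute existed.
@register_module_loader(("myproj.modules", "MyNorm"))
def mynorm_loader(expr):
    if not hasattr(expr, "version"):      # pre-1.7 dumps carry no version field
        module = expr.inputs[0].owner
        if not hasattr(module, "eps"):
            module.eps = 1e-5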

imperative/python/megengine/traced_module/expr.py (+276, -27)

@@ -11,19 +11,28 @@ import collections
 import copy
 import inspect
 import re
-from typing import Callable, Dict, List, Optional, Union
+import weakref
+from importlib import import_module
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 from ..core._imperative_rt import OpDef
 from ..core._imperative_rt.core2 import Tensor as RawTensor
-from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ..core._imperative_rt.core2 import (
+    apply,
+    is_tracing_module,
+    set_module_tracing,
+    unset_module_tracing,
+)
 from ..core.ops.builtin import FakeQuant
 from ..core.ops.special import Const
 from ..module import Module
 from ..tensor import Parameter, Tensor
+from ..version import __version__
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
 from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
-from .serialization import get_opdef_state, load_opdef_from_state
+from .serialization import _ModuleState
+from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args


 def rstrip(s: str, __chars: str):
@@ -112,6 +121,7 @@ class Expr:
             node.users.append(self)
         else:
             assert node is None
+            assert not isinstance(val, (Module, RawTensor))
             assert _is_leaf(val) and _is_const_leaf(val)
             idx = len(self.inputs) + len(self.const_val)
             self.const_val.append((idx, val))
@@ -132,14 +142,14 @@ class Expr:
         current_graph._namespace.auto_naming_for_outputs(self)

     def unflatten_args(self, inputs):
-        if self.arg_def is not None:
-            inputs = list(inputs)
-            for idx, val in self.const_val:
-                inputs.insert(idx, val)
-            args, kwargs = self.arg_def.unflatten(inputs)
-            return args, kwargs
-        else:
-            return inputs, {}
+        assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format(
+            type(self).__name__
+        )
+        inputs = list(inputs)
+        for idx, val in self.const_val:
+            inputs.insert(idx, val)
+        args, kwargs = self.arg_def.unflatten(inputs)
+        return args, kwargs

     def replace_inputs(self, repl_dict: Dict[Node, Node]):
         r"""Replace the input Nodes of this Expr.
@@ -166,6 +176,39 @@ class Expr:
         repl_node.users.append(self)

     @property
+    def _support_set_args_kwargs(self):
+        return False
+
+    def set_args_kwargs(self, *args, **kwargs):
+        r"""Set args and kwargs for Expr.
+        """
+        assert (
+            self._support_set_args_kwargs
+        ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
+        args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
+        inputs, arg_def = tree_flatten((args, kwargs))
+        orig_inputs = self.inputs
+        self.inputs = []
+        self.const_val = []
+        for val in inputs:
+            if isinstance(val, (TensorNode, ModuleNode)):
+                self.inputs.append(val)
+            else:
+                assert _is_leaf(val) and _is_const_leaf(val)
+                idx = len(self.inputs) + len(self.const_val)
+                self.const_val.append((idx, val))
+
+        for n in orig_inputs:
+            if n not in self.inputs:
+                n.users.remove(self)
+
+        for n in self.inputs:
+            if n not in orig_inputs:
+                n.users.append(self)
+
+        self.arg_def = arg_def
+
+    @property
     def kwargs(self):
         r"""Get the keyword arguments of the operation corresponding to this Expr."""
         _, kwargs = self.unflatten_args(self.inputs)
@@ -177,6 +220,61 @@ class Expr:
         args, _ = self.unflatten_args(self.inputs)
         return args

+    def _get_func(self):
+        # get the called function when the expr is interpreted
+        raise NotImplementedError
+
+    @property
+    def named_args(self):
+        func = self._get_func()
+        return inspect.getcallargs(func, *self.args, **self.kwargs)
+
+    def set_arg(self, name, val):
+        func = self._get_func()
+        if name in self.kwargs:
+            new_kwargs = self.kwargs
+            new_kwargs[name] = val
+            self.set_args_kwargs(*self.args, **new_kwargs)
+        else:
+            arg_spec = inspect.getfullargspec(func)
+            if name in arg_spec.args:
+                ind = arg_spec.args.index(name)
+                new_args = list(self.args)
+                new_args[ind] = val
+                self.set_args_kwargs(*new_args)
+            elif name == arg_spec.varargs:
+                assert arg_spec.varargs is not None
+                assert len(self.args) >= len(arg_spec.args)
+                val = (val,) if not isinstance(val, Sequence) else val
+                self.set_args_kwargs(*self.args[0 : len(arg_spec.args)], *val)
+            else:
+                assert (
+                    arg_spec.varkw is not None
+                ), "func {} doesn't have argument named {}".format(func, name)
+                new_kwargs = self.kwargs
+                new_kwargs[name] = val
+                self.set_args_kwargs(*self.args, **new_kwargs)
+
+    @property
+    def return_val(self):
+        return self.out_def.unflatten(self.outputs)
+
+    @return_val.setter
+    def return_val(self, new_outputs):
+        outputs, out_def = tree_flatten(
+            new_outputs, is_leaf=lambda x: isinstance(x, Node)
+        )
+        assert all(
+            isinstance(o, Node) for o in outputs
+        ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
+        assert all(
+            o.expr in (None, self) for o in outputs
+        ), "Some nodes are produced by other expr, can not be output of expr {}".format(
+            self
+        )
+        self.outputs = outputs
+        self.out_def = out_def
+
     @property
     def top_graph(self):
         r"""Get the parent graph of this Expr."""
@@ -184,12 +282,6 @@ class Expr:
             return self._top_graph()
         return None

-    def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
-        return state
-
     @classmethod
     def _get_next_id(cls):
         return cls.__total_id
@@ -199,6 +291,23 @@ class Expr:
         assert isinstance(id, int)
         cls.__total_id = id

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType):
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result
+

 # expr: None (i.e. fake expression which is used to mark input)
 class Input(Expr):
@@ -229,6 +338,17 @@ class Input(Expr):
     def __repr__(self):
         return "%{}:\t{} = Input()".format(self._id, self.outputs[0])

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "name": self.name,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = getattr(inputs[0], self.name)
 class GetAttr(Expr):
@@ -276,11 +396,23 @@ class GetAttr(Expr):
     def __repr__(self):
         out_type = "Tensor"
         if isinstance(self.outputs[0], ModuleNode):
-            out_type = self.outputs[0].module_type.__name__
+            m_type = self.outputs[0].module_type
+            out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1]
         return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
             self._id, self.outputs[0], self.inputs[0], self.name, out_type
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "name": self.name,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = inputs[0].__call__(*inputs[1:])
 class CallMethod(Expr):
@@ -307,6 +439,7 @@ class CallMethod(Expr):
             node,
         ]
         self.const_val = []
+        self.arg_def = tree_flatten(((node,), {}))[1]
         self.method = method

     @classmethod
@@ -342,6 +475,27 @@ class CallMethod(Expr):
         outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs

+    def _get_func(self):
+        if isinstance(self.args[0], type):
+            obj_type = self.args[0]
+        elif isinstance(self.args[0], ModuleNode):
+            obj_type = self.args[0].module_type
+        else:
+            assert isinstance(self.args[0], TensorNode)
+            obj_type = Tensor
+        meth = getattr(
+            obj_type, "forward" if issubclass(obj_type, Module) else self.method
+        )
+        return meth
+
+    @property
+    def _support_set_args_kwargs(self):
+        # only exprs that call a tensor method or a builtin module support
+        # modifying args/kwargs
+        return (
+            isinstance(self.args[0], (TensorNode, type))
+            or self.args[0].module_type is not Module
+        )
+
     def __repr__(self):
         args = ", ".join(str(i) for i in self.args[1:])
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
@@ -359,6 +513,21 @@ class CallMethod(Expr):
             ", ".join([args, kwargs]),
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "inputs": self.inputs,
+            "const_val": self.const_val,
+            "method": self.method,
+            "arg_def": self.arg_def,
+            "out_def": self.out_def,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
+        return state
+

 # expr: outputs = apply(self.opdef, *inputs)
 class Apply(Expr):
@@ -394,14 +563,32 @@ class Apply(Expr):
         )

     def __getstate__(self):
-        state = super().__getstate__()
-        state["opdef"] = get_opdef_state(state["opdef"])
+        opdef_state = self.opdef.__getstate__()
+        opdef_state["opdef_type"] = type(self.opdef)
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "opdef_state": opdef_state,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
         return state

     def __setstate__(self, state):
-        state["opdef"] = load_opdef_from_state(state["opdef"])
-        for k, v in state.items():
-            setattr(self, k, v)
+        # compat with mge 1.6
+        if "opdef" in state and "opdef_state" not in state:
+            opdef_state = state.pop("opdef")
+            opdef_state["opdef_type"] = opdef_state.pop("type")
+            state["opdef_state"] = opdef_state
+        self.__dict__.update(state)
+        assert isinstance(state["opdef_state"], dict)
+        opdef_state = state["opdef_state"].copy()
+        opdef_type = opdef_state.pop("opdef_type")
+        opdef_obj = opdef_type()
+        opdef_obj.__setstate__(opdef_state)
+        setattr(self, "opdef", opdef_obj)

     @classmethod
     def apply_module_trace_hook(cls, opdef, *inputs):
@@ -458,12 +645,24 @@ class CallFunction(Expr):

     def interpret(self, *inputs):
         args, kwargs = self.unflatten_args(inputs)
-        outputs = self.func(*args, **kwargs)
+        func = (
+            self.func
+            if not is_tracing_module()
+            else active_module_tracer().patcher.wrap_fn(self.func)
+        )
+        outputs = func(*args, **kwargs)
         if outputs is None:
             return outputs
         outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs

+    def _get_func(self):
+        return self.func
+
+    @property
+    def _support_set_args_kwargs(self):
+        return True
+
     def __repr__(self):
         args = ", ".join(str(i) for i in self.args)
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
@@ -477,6 +676,33 @@ class CallFunction(Expr):
             ", ".join([args, kwargs]),
         )

+    def __getstate__(self):
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "func": (self.func.__module__, self.func.__qualname__),
+            "const_val": self.const_val,
+            "inputs": self.inputs,
+            "arg_def": self.arg_def,
+            "out_def": self.out_def,
+            "outputs": self.outputs,
+            "version": __version__,
+        }
+        _check_obj_attr(state)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        try:
+            if isinstance(self.func, tuple):
+                mname, fname = self.func
+                f = import_module(mname)
+                for i in fname.split("."):
+                    f = getattr(f, i)
+                self.func = f
+        except Exception:
+            pass
+

 # expr: outputs = self.value
 class Constant(Expr):
@@ -496,6 +722,13 @@ class Constant(Expr):
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
             assert module_tracer.is_builtin(c) or c.is_qat
+        if isinstance(c, RawTensor):
+            if is_tracing_module():
+                unset_module_tracing()
+                c = Tensor(c)
+                set_module_tracing()
+            else:
+                c = Tensor(c)
         self.value = c
         self.name = name
         self.inputs = []
@@ -530,9 +763,25 @@ class Constant(Expr):
         )

     def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
+        state = {
+            "_id": self._id,
+            "_disable_remove": self._disable_remove,
+            "value": self.value,
+            "name": self.name,
+            "inputs": self.inputs,
+            "outputs": self.outputs,
+        }
+        _check_obj_attr(state)
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
+        if isinstance(self.value, Module) and module_tracer.is_builtin(self.value):
+            _check_builtin_module_attr(self.value)
+            state["value"] = _ModuleState.get_module_state(self.value)
+
         return state
+
+    def __setstate__(self, state):
+        for k, v in state.items():
+            if isinstance(v, _ModuleState):
+                state[k] = v.to_module()
+        self.__dict__.update(state)
imperative/python/megengine/traced_module/module_tracer.py (+0, -1)

@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
     "astype",
     "reshape",
     "_broadcast",
-    "transpose",
     "flatten",
     "sum",
     "prod",


imperative/python/megengine/traced_module/node.py (+54, -5)

@@ -6,7 +6,9 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import abc
+import copy
 import weakref
+from importlib import import_module
 from typing import Any, Dict, List, Tuple, Type

 import numpy
@@ -14,7 +16,9 @@ import numpy

 from .. import get_logger
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..module import Module
+from ..quantization.utils import QParams
 from ..tensor import Tensor
+from .utils import _check_obj_attr

 logger = get_logger(__name__)

@@ -145,6 +149,23 @@ class Node:
         assert isinstance(id, int)
         cls.__total_id = id

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result
+

 class ModuleNode(Node):
     r"""``ModuleNode`` represents the Module objects."""
@@ -157,19 +178,28 @@ class ModuleNode(Node):
         super().__init__(expr, name, qualname)

     def __getstate__(self):
-        return {
+        state = {
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
             "_name": self._name,
             "_qualname": self._qualname,
-            "module_type": self.module_type,
+            "module_type": (self.module_type.__module__, self.module_type.__qualname__),
         }
+        _check_obj_attr(state)
+        return state

     def __setstate__(self, state):
         if "_orig_name" in state:
             state["_qualname"] = state.pop("_orig_name")
         self.__dict__.update(state)
+        try:
+            if isinstance(self.module_type, tuple):
+                mname, classname = self.module_type
+                mtype = getattr(import_module(mname), classname)
+                self.module_type = mtype
+        except Exception:
+            pass

     @property
     def owner(self):
@@ -185,12 +215,26 @@ class TensorNode(Node):

     _shape = None  # type: Tuple[int]
     _dtype = None  # type: numpy.dtype
-    _qparams = None
+    _qparams = None  # type: QParams
     _device = None
     _value = None  # type: Tensor

+    def __init__(
+        self,
+        expr: "Expr",
+        name: str = None,
+        qualname: str = None,
+        shape: Tuple[int] = None,
+        dtype: numpy.dtype = None,
+        qparams: QParams = None,
+    ):
+        super().__init__(expr, name, qualname)
+        self._shape = shape
+        self._dtype = dtype
+        self._qparams = qparams
+
     def __getstate__(self):
-        return {
+        state = {
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
@@ -201,6 +245,8 @@ class TensorNode(Node):
             "_name": self._name,
             "_qualname": self._qualname,
         }
+        _check_obj_attr(state)
+        return state

     def __setstate__(self, state):
         if "_orig_name" in state:
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
             assert isinstance(node, TensorNode)
             assert isinstance(value, RawTensor)
             if isinstance(value, RawTensor):
-                node._dtype = value.dtype
+                try:
+                    node._dtype = value.dtype
+                except RuntimeError:
+                    node._dtype = None
                 node._shape = (
                     value._tuple_shape if isinstance(value, Tensor) else value.shape
                 )


imperative/python/megengine/traced_module/pytree.py (+51, -31)

@@ -7,15 +7,18 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 import collections
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from functools import partial
 from typing import Callable, NamedTuple

 import numpy as np

+from ..core._imperative_rt import OpDef
 from ..core._imperative_rt.common import CompNode
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._wrap import Device
 from ..core.tensor.dtype import QuantDtypeMeta
+from ..distributed import Group
 from ..module import Module
 from ..quantization.utils import LSQParams, QParams, QuantMode
 from ..tensor import Parameter, Tensor
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
     type(Ellipsis),
     QuantMode,
     ArgsIndex,
+    Group,
 }

+USER_REGISTERED_LEAF_TYPE = []
+USER_REGISTERED_CONTAINER_TYPE = []
 # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be treated as leaf node of pytree
-SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]
+SUPPORTED_LEAF_CLS = [
+    Module,
+    Node,
+    NodeMixin,
+    np.dtype,
+    np.ndarray,
+    np.number,
+    np.bool_,
+    OpDef,
+]

 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


 def register_supported_type(type, flatten=None, unflatten=None):
+    tp_info = (type.__module__, type.__qualname__)
     if flatten and unflatten:
-        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+        USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
     else:
-        SUPPORTED_LEAF_CLS.append(type)
+        USER_REGISTERED_LEAF_TYPE.append(tp_info)
+    _register_supported_type(type, flatten, unflatten)


-def _dict_flatten(inp):
-    aux_data = []
-    results = []
-    for key, value in sorted(inp.items()):
-        results.append(value)
-        aux_data.append(key)
-    return results, tuple(aux_data)
+def _register_supported_type(type, flatten=None, unflatten=None):
+    if flatten and unflatten:
+        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+    else:
+        SUPPORTED_LEAF_CLS.append(type)


-def _dict_unflatten(inps, aux_data):
-    return dict(zip(aux_data, inps))
-
-
-def _ordereddict_flatten(inp):
+def _dict_flatten(ordered, inp):
     aux_data = []
     results = []
-    for key, value in inp.items():
+    dict_items = inp.items() if ordered else sorted(inp.items())
+    for key, value in dict_items:
         results.append(value)
         aux_data.append(key)
     return results, tuple(aux_data)


-def _ordereddict_unflatten(inps, aux_data):
-    return OrderedDict(zip(aux_data, inps))
+def _dict_unflatten(dict_type, inps, aux_data):
+    return dict_type(zip(aux_data, inps))


 def qparams_flatten(inp):
@@ -99,33 +111,41 @@ def qparams_flatten(inp):
     return results, tuple(aux_data)


-def qparams_unflatten(inp, aux_data):
-    obj = QParams.__new__(QParams)
+def qparams_unflatten(qparam_type, inp, aux_data):
+    obj = qparam_type.__new__(qparam_type)
     for k, v in zip(aux_data, inp):
         setattr(obj, k, v)
     return obj


-register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
-register_supported_type(dict, _dict_flatten, _dict_unflatten)
-register_supported_type(
-    collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
+_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
+_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
+_register_supported_type(
+    dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
+)
+_register_supported_type(
+    defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
 )
-register_supported_type(
+_register_supported_type(
+    OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
+)
+_register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
     lambda x, aux_data: slice(x[0], x[1], x[2]),
 )

-register_supported_type(QParams, qparams_flatten, qparams_unflatten)
+_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
+_register_supported_type(
+    LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
+)




 def _is_leaf(obj):
-    if isinstance(obj, type):
-        return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
+    obj_type = obj if isinstance(obj, type) else type(obj)
     return (
-        isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE
+        issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS))
+        or obj_type in SUPPORTED_LEAF_TYPE
     )
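A sketch of how a user container flows through this registry; `Point` is hypothetical, and `register_supported_type` (above) additionally records it in USER_REGISTERED_CONTAINER_TYPE so dumps can report what was registered:

from collections import namedtuple
from megengine.traced_module.pytree import register_supported_type, tree_flatten

Point = namedtuple("Point", ["x", "y"])

register_supported_type(
    Point,
    flatten=lambda p: (list(p), None),             # leaves, aux_data
    unflatten=lambda leaves, aux: Point(*leaves),  # rebuild from leaves
)

leaves, treedef = tree_flatten(Point(1, 2))        # leaves == [1, 2]
assert treedef.unflatten(leaves) == Point(1, 2)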






imperative/python/megengine/traced_module/serialization.py (+146, -18)

@@ -5,30 +5,158 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Dict
+from importlib import import_module
+from typing import Dict, Tuple

 from ..core._imperative_rt import OpDef
 from ..core.ops import builtin
+from ..tensor import Tensor
 from ..version import __version__
+from .utils import _convert_kwargs_to_args

-OPDEF_PARAM_LOADER = {}
+OPDEF_LOADER = {}
+FUNCTIONAL_LOADER = {}
+TENSORMETHOD_LOADER = {}
+MODULE_LOADER = {}


-def get_opdef_state(obj: OpDef) -> Dict:
-    state = obj.__getstate__()
-    state["type"] = type(obj)
-    state["version"] = __version__
-    return state
-
-
-def load_opdef_from_state(state: Dict) -> OpDef:
-    assert "type" in state and issubclass(state["type"], OpDef)
-    assert "version" in state
-    opdef_type = state.pop("type")
-    if opdef_type in OPDEF_PARAM_LOADER:
-        loader = OPDEF_PARAM_LOADER[opdef_type]
-        state = loader(state)
-    state.pop("version")
-    opdef_obj = opdef_type()
-    opdef_obj.__setstate__(state)
-    return opdef_obj
+class _ModuleState:
+    obj = None
+
+    def __init__(self, module: Tuple, state: Dict, version: str):
+        self.module = module
+        self.state = state
+        self.version = version
+
+    @classmethod
+    def get_module_state(cls, module):
+        typem = (type(module).__module__, type(module).__qualname__)
+        state = module.__dict__.copy()
+        state.pop("_m_dump_modulestate", None)
+        if hasattr(module, "_m_dump_modulestate"):
+            assert isinstance(module._m_dump_modulestate, cls)
+            module._m_dump_modulestate.__init__(typem, state, __version__)
+        else:
+            module.__dict__["_m_dump_modulestate"] = _ModuleState(
+                typem, state, __version__
+            )
+        return module._m_dump_modulestate
+
+    def __getstate__(self):
+        return {"module": self.module, "state": self.state, "version": self.version}
+
+    def to_module(self):
+        if self.obj is None:
+            typem = getattr(import_module(self.module[0]), self.module[1])
+            m_obj = typem.__new__(typem)
+            m_obj.__dict__.update(self.state)
+            self.obj = m_obj
+        return self.obj
+
+
+def register_opdef_loader(*opdefs):
+    def callback(loader):
+        for opdef in opdefs:
+            assert opdef not in OPDEF_LOADER
+            OPDEF_LOADER[opdef] = loader
+        return loader
+
+    return callback
+
+
+def register_functional_loader(*funcs):
+    def callback(loader):
+        for func in funcs:
+            assert func not in FUNCTIONAL_LOADER
+            FUNCTIONAL_LOADER[func] = loader
+        return loader
+
+    return callback
+
+
+def register_module_loader(*module_types):
+    def callback(loader):
+        for module_type in module_types:
+            assert module_type not in MODULE_LOADER
+            MODULE_LOADER[module_type] = loader
+        return loader
+
+    return callback
+
+
+def register_tensor_method_loader(*methods):
+    def callback(loader):
+        for method in methods:
+            assert method not in TENSORMETHOD_LOADER
+            TENSORMETHOD_LOADER[method] = loader
+        return loader
+
+    return callback
+
+
+def _replace_args_kwargs(expr, new_args, new_kwargs):
+    if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
+        expr.kwargs.keys()
+    ):
+        expr.set_args_kwargs(*new_args, **new_kwargs)
+
+
+def load_functional(expr):
+    func = (
+        (expr.func.__module__, expr.func.__qualname__)
+        if callable(expr.func)
+        else expr.func
+    )
+    assert isinstance(func, tuple)
+    if func in FUNCTIONAL_LOADER:
+        loader = FUNCTIONAL_LOADER[func]
+        loader(expr)
+    mname, fname = func
+    f = import_module(mname)
+    for i in fname.split("."):
+        f = getattr(f, i)
+    expr.func = f
+    assert callable(expr.func)
+    if not hasattr(expr, "version") or expr.version != __version__:
+        args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_call_module_expr(expr):
+    m_type = expr.inputs[0].module_type
+    if isinstance(m_type, type):
+        m_type = (m_type.__module__, m_type.__qualname__)
+    if m_type in MODULE_LOADER:
+        MODULE_LOADER[m_type](expr)
+    if isinstance(expr.inputs[0].module_type, tuple):
+        mname, classname = expr.inputs[0].module_type
+        expr.inputs[0].module_type = getattr(import_module(mname), classname)
+    if not hasattr(expr, "version") or expr.version != __version__:
+        fwd_func = getattr(expr.inputs[0].module_type, "forward")
+        args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_call_tensor_method_expr(expr):
+    if expr.method in TENSORMETHOD_LOADER:
+        loader = TENSORMETHOD_LOADER[expr.method]
+        loader(expr)
+    if not hasattr(expr, "version") or expr.version != __version__:
+        tmethod = (
+            getattr(expr.args[0], expr.method)
+            if isinstance(expr.args[0], type)
+            else getattr(Tensor, expr.method)
+        )
+        args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
+        _replace_args_kwargs(expr, args, kwargs)
+
+
+def load_apply_expr(expr):
+    opdef_type = type(expr.opdef)
+    if opdef_type in OPDEF_LOADER:
+        OPDEF_LOADER[opdef_type](expr)
+    opdef_state = expr.opdef_state
+    opdef_obj = opdef_state.pop("opdef_type")()
+    opdef_obj.__setstate__(opdef_state)
+    expr.opdef = opdef_obj
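For orientation: these load_* helpers are invoked at unpickle time from TracedModule.__setstate__ (later in this diff). A simplified restatement of that dispatch, per expr:

from megengine.traced_module.expr import Apply, CallFunction, CallMethod

def _load_expr(expr):
    # simplified sketch of the dispatch in TracedModule.__setstate__
    if isinstance(expr, CallFunction):
        load_functional(expr)                   # FUNCTIONAL_LOADER + arg re-normalization
    elif isinstance(expr, CallMethod):
        if expr.method == "__call__":
            load_call_module_expr(expr)         # MODULE_LOADER
        else:
            load_call_tensor_method_expr(expr)  # TENSORMETHOD_LOADER
    elif isinstance(expr, Apply):
        load_apply_expr(expr)                   # OPDEF_LOADER + opdef reconstruction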

imperative/python/megengine/traced_module/traced_module.py (+207, -79)

@@ -14,6 +14,7 @@ import inspect
 import keyword
 import re
 import weakref
+from importlib import import_module
 from inspect import getcallargs, getmembers, isclass, ismethod
 from itertools import chain
 from types import FunctionType
@@ -53,6 +54,7 @@ from ..quantization.observer import (
     SyncMinMaxObserver,
 )
 from ..tensor import Tensor
+from ..version import __version__
 from .expr import (
     Apply,
     CallFunction,
@@ -80,8 +82,27 @@ from .module_tracer import (
     set_active_module_tracer,
 )
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import ArgsIndex, tree_flatten
-from .utils import replace_container_with_module_container
+from .pytree import (
+    USER_REGISTERED_CONTAINER_TYPE,
+    USER_REGISTERED_LEAF_TYPE,
+    ArgsIndex,
+    TreeDef,
+    _register_supported_type,
+    tree_flatten,
+)
+from .serialization import (
+    _ModuleState,
+    load_apply_expr,
+    load_call_module_expr,
+    load_call_tensor_method_expr,
+    load_functional,
+)
+from .utils import (
+    _check_builtin_module_attr,
+    _check_obj_attr,
+    _convert_kwargs_to_args,
+    replace_container_with_module_container,
+)

 logger = get_logger(__name__)


@@ -341,7 +362,7 @@ class NameSpace:
     def create_unique_name(self, name: str, node: Any = None) -> str:
         assert isinstance(name, str), "The name must be a string"

-        if name in self._used_names and self._used_names[name] is node:
+        if name in self._used_names and (self._used_names[name] is node):
             return name

         name = re.sub("[^0-9a-zA-Z_]+", "_", name)
@@ -1067,6 +1088,7 @@ class InternalGraph:
                 if node2value[n][1] == 0:
                     node2value.pop(n)
         if values is not None:
+            assert len(values) == len(expr.outputs)
             for n, v in zip(expr.outputs, values):
                 if ref_count(n) > 0:
                     node2value[n] = [v, ref_count(n)]
@@ -1105,13 +1127,27 @@ class InternalGraph:
         return res

     def __getstate__(self):
-        state = self.__dict__.copy()
-        if "_top_graph" in state:
-            state.pop("_top_graph")
+        state = {
+            "_exprs": self._exprs,
+            "_inputs": self._inputs,
+            "_outputs": self._outputs,
+            "_watch_point": [],
+            "_end_point": [],
+            "_namespace": self._namespace,
+            "_rst": collections.defaultdict(list),
+            "_name": self._name,
+            "_qualname": self._qualname,
+        }
+        if self._total_ids:
+            state["_total_ids"] = self._total_ids
+
+        _check_obj_attr(state)
+
         return state

     def __setstate__(self, state):
         old_version = False
         if "_module_name" in state:
             old_version = True
             state["_qualname"] = state.pop("_module_name")
@@ -1144,6 +1180,25 @@ class InternalGraph:
             self._namespace = NameSpace(self._name, self._qualname)
             self._re_associate_name()

+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType):
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        return result
+

 def _get_meth_name(obj, func):
     tp = obj if isinstance(obj, type) else type(obj)
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
 def _wrapped_function(orig_func):
     @functools.wraps(orig_func)
     def wrapped_fn(*args, **kwargs):
-        method_func = wrapped_fn
-        if "method_func" in kwargs:
-            method_func = kwargs.pop("method_func")
+        method_func = kwargs.pop("method_func", wrapped_fn)
         if is_tracing_module():
             unset_module_tracing()
             inputs, tree_def = tree_flatten((args, kwargs))
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
                 if not NodeMixin.get(i, None):
                     if isinstance(i, (RawTensor, NodeMixin)):
                         NodeMixin.wrap_safe(i, Constant.make(i))
-            meth_name, arg_type = None, None
-            if args:
-                meth_name = _get_meth_name(args[0], method_func)
-                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+            args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs)
+            meth_name = _get_meth_name(args[0], method_func)
+            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
             if meth_name and arg_type and issubclass(arg_type, RawTensor):
+                inputs, tree_def = tree_flatten((args, kwargs))
                 self = inputs[0]
                 if meth_name == "__new__":
                     if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
                 call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                 call_node.add_inputs(inputs[1:])
             else:
+                inputs, tree_def = tree_flatten((args, kwargs))
                 call_node = CallFunction.make(orig_func)
                 call_node.add_inputs(inputs)

@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
         "_record_wrapped_nodes",
         "_argdef_graph_map",
         "_argdef_outdef_map",
+        "_check_qat_module",
         "nodes",
         "__class__",
         "__dict__",
+        "_is_top",
     ]

     def __init__(self, mod, is_top_module=False):
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
                 qat_module.weight_fake_quant.set_qparams(qparams)

     def build(self):
-        if self._is_builtin or isinstance(self._mod, TracedModule):
-            if module_tracer.is_builtin(self._mod) or isinstance(
-                self._mod, TracedModule
-            ):
-                mod_type = type(self._mod)
-            else:
-                assert isinstance(self._mod, (Observer, _FakeQuantize))
-                mod_type = (
-                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
-                )
+        if self._is_builtin:
+            assert module_tracer.is_builtin(self._mod)
+            mod_type = type(self._mod)
             for node in self.nodes:
                 node.module_type = mod_type
             return self._mod
         else:
-            is_qat = isinstance(self._mod, QATModule)
+            is_qat = isinstance(self._mod, QATModule) or (
+                isinstance(self._mod, TracedModule) and self._mod.is_qat
+            )
             traced_module = TracedModule(
                 self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
             )
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
                 traced_module.with_act = self._mod.with_act
                 traced_module.with_weight = self._mod.with_weight
                 if not hasattr(traced_module, "act_fake_quant"):
-                    traced_module.act_fakequant = None
+                    traced_module.act_fake_quant = None
                 if not hasattr(traced_module, "act_observer"):
                     traced_module.act_observer = None
                 if not hasattr(traced_module, "weight_fake_quant"):
-                    traced_module.weight_fakequant = None
+                    traced_module.weight_fake_quant = None
                 if not hasattr(traced_module, "weight_observer"):
                     traced_module.weight_observer = None
             set_module_tracing()

+        if self._is_top:
+            traced_module._update_ref()
+
         return traced_module

     def _record_wrapped_nodes(self, node):
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
         # prepare args and kwargs for inner graph
         if "method_func" in kwargs:
             kwargs.pop("method_func")
+        args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)

         def mark_constant(x):
             node = NodeMixin.get(x, None)
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):

         callnode.arg_def = tree_def

-        if (
-            self._is_builtin
-            or tree_def in self._argdef_graph_map
-            or isinstance(self._mod, TracedModule)
-        ):
+        if self._is_builtin or tree_def in self._argdef_graph_map:
             unset_module_tracing()
             rst = self._mod(*args, **kwargs)
             outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
             self._body = None
         elif tree_def in self._argdef_graph_map:
             self._body = self._argdef_graph_map[tree_def]
-        else:
-            self._mod._is_top = False
-            self._body = self._mod.argdef_graph_map[tree_def]
-            module_qualname = NodeMixin.get(self).qualname
-            if module_qualname != self._body.qualname:
-                src_name, dst_name = self._body.qualname, module_qualname
-
-                def replace_qualname(g):
-                    attr_name = get_suffix_name(src_name, g.qualname)
-                    if attr_name is not None:
-                        g._qualname = (
-                            ("%s.%s" % (dst_name, attr_name))
-                            if attr_name
-                            else dst_name
-                        )
-                        assert get_suffix_name(dst_name, g.qualname) is not None
-
-                for mod in self._mod.modules():
-                    if not hasattr(mod, "argdef_graph_map"):
-                        continue
-                    for g in mod.argdef_graph_map.values():
-                        replace_qualname(g)
-                        g._namespace.qualname = g.qualname
-                        for n in g.nodes(False):
-                            replace_qualname(n)
         else:
-            self_node = None
             orig_self = NodeMixin.get(self)
             parent_graph = active_module_tracer().current_scope()
             module_qualname = orig_self._qualname
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
             active_module_tracer().push_scope(self._body)
             # rebind self to new input node

-            if self_node:
-                NodeMixin.wrap_safe(self, self_node)
-                active_module_tracer().current_scope()._add_input(self_node)
-            else:
-                NodeMixin.wrap_safe(
-                    self,
-                    self_node
-                    if self_node
-                    else Input.make(
-                        name="self",
-                        qualname=module_qualname,
-                        type=NodeMixin.get_wrapped_type(self),
-                    ),
-                )
+            NodeMixin.wrap_safe(
+                self,
+                Input.make(
+                    name="self",
+                    qualname=module_qualname,
+                    type=NodeMixin.get_wrapped_type(self),
+                ),
+            )

             origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
             # prepare args and kwargs for inner graph
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
                 return x

             args = [self]
-            for i, v in enumerate(inputs[1:]):
-                args.append(wrap(v, idx2key[i + 1]))
+            orig_traced_inputs = (
+                None
+                if not isinstance(self._mod, TracedModule)
+                else self._mod.argdef_graph_map[tree_def].inputs
+            )
+            ind = 1
+            for v in inputs[1:]:
+                if isinstance(v, (RawTensor, NodeMixin)):
+                    args_name = (
+                        orig_traced_inputs[ind]._name
+                        if orig_traced_inputs
+                        else idx2key[ind]
+                    )
+                    ind += 1
+                    args.append(wrap(v, args_name))
+                else:
+                    args.append(v)

             args, kwargs = tree_def.unflatten(args)
             active_module_tracer().patcher.auto_patch(
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
             attr = getattr(type(self._mod), name).__get__(self, type(self))
         else:
             attr = getattr(self._mod, name)
-
             if (
                 isinstance(attr, FunctionType)
                 and id(attr) in active_module_tracer().patcher.patched_fn_ids
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
         wrapped = self.__getattr__(name)

         if isinstance(wrapped, TracedModuleBuilder):
-            if not isinstance(mod_attr, (List, Dict)):
+            if not isinstance(mod_attr, (List, Dict, QATModule)):
                 assert mod_attr is wrapped._mod
         else:
             assert mod_attr is wrapped
@@ -1977,8 +2011,6 @@ class TracedModule(Module):
     def graph(self) -> InternalGraph:
         """Return the ``InternalGraph`` of this ``TracedModule``.
         """
-        if self._is_top:
-            self._update_ref()
         assert len(self.argdef_graph_map) == 1
         return list(self.argdef_graph_map.values())[0]


@@ -2112,7 +2144,7 @@ class TracedModule(Module):
                 if hasattr(obj, "argdef_graph_map")
                 else None
             )
-            if expr_graph is not None:
+            if expr_graph is not None and not obj.is_qat:
                 exprs = _flatten_subgraph(graph, expr_graph, expr, obj)

         if parent_graph is not None:
@@ -2137,26 +2169,119 @@ class TracedModule(Module):
         )
         new_module.graph._re_associate_name()
         new_module.graph.compile()
+        new_module._update_ref()
         new_module.graph._reset_ids()
         return new_module

     def __getstate__(self):
-        d = self.__dict__
+        d = self.__dict__.copy()
         for k in Module.__dict__:
             d.pop(k, None)
+        _check_obj_attr(d)
+        for k in d:
+            if module_tracer.is_builtin(d[k]):
+                assert _check_builtin_module_attr(
+                    d[k]
+                ), "Module {} can not be serialized. ".format(type(d[k]))
+                d[k] = _ModuleState.get_module_state(d[k])
+        dump_info = {
+            "version": __version__,
+            "register_type": USER_REGISTERED_LEAF_TYPE,
+            "register_container_type": USER_REGISTERED_CONTAINER_TYPE,
+            "register_module": USER_REGISTERED_MODULE,
+            "register_function": USER_REGISTERED_FUNCTION,
+        }
+        d["dump_info"] = dump_info
         return d

+    def __setstate__(self, state):
+        for k, v in state.items():
+            if isinstance(v, _ModuleState):
+                state[k] = v.to_module()
+        self.__dict__.update(state)
+        self._update_ref()
+
+        for _, graph in self.argdef_graph_map.items():
+            for expr in graph._exprs:
+                if isinstance(expr, CallFunction):
+                    load_functional(expr)
+                if isinstance(expr, CallMethod):
+                    if expr.method == "__call__":
+                        load_call_module_expr(expr)
+                    else:
+                        load_call_tensor_method_expr(expr)
+                if isinstance(expr, Apply):
+                    load_apply_expr(expr)
+
+        for _, graph in self.argdef_graph_map.items():
+            ind = 0
+            while ind < len(graph._exprs):
+                cur_expr = graph._exprs[ind]
+                has_new_expr = False
+                for i in cur_expr.inputs:
+                    if i.expr not in graph._exprs and not isinstance(i.expr, Input):
+                        graph._exprs.insert(ind, i.expr)
+                        has_new_expr = True
+                if not has_new_expr:
+                    ind += 1
+            for expr in graph._exprs:
+                for i in expr.inputs:
+                    if expr.inputs.count(i) != i.users.count(expr):
+                        add_or_del_count = expr.inputs.count(i) - i.users.count(expr)
+                        if add_or_del_count > 0:
+                            i.users.extend([expr] * add_or_del_count)
+                        else:
+                            [i.users.remove(expr) for _ in range(-add_or_del_count)]
+
+                for o in expr.outputs:
+                    if o.expr is not expr:
+                        assert o not in o.expr.outputs
+                        o.expr = expr
+            for node in graph.nodes(False):
+                # drop users that no longer take this node as an input
+                node.users = [e for e in node.users if node in e.inputs]
+
+            for expr in graph._exprs:
+                graph._namespace.auto_naming_for_outputs(expr)
+        self._update_ref()
+        for _, graph in self.argdef_graph_map.items():
+            graph._reset_ids()
+
+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        state = {}
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if not isinstance(v, weakref.ReferenceType):
+                state[k] = copy.deepcopy(v, memo)
+        result.__dict__.update(state)
+        result._update_ref()
+        return result
+



 def cpp_apply_module_trace(opdef, *args):
     return Apply.apply_module_trace_hook(opdef, *args)


+USER_REGISTERED_MODULE = []
+USER_REGISTERED_FUNCTION = []
+
+
 def register_as_builtin(mod_cls: Type[Module]) -> None:
     r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.

     Args:
         mod_cls: the module class which will be treated as builtin module in tracing.
     """
+    USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__))
     module_tracer.register_as_builtin(mod_cls)


@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
     Args:
         func: the global function to insert into the graph when it is called.
     """
+    USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__))
     assert callable(func), "func must be a callable"
     assert hasattr(func, "__code__")
     fn_name = func.__code__.co_name
@@ -2247,6 +2373,8 @@ def trace_module(
         NodeMixin.wrap_safe(
             builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
         )
+        args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)
+
         inputs, _ = tree_flatten((args, kwargs))
         for _, i in enumerate(inputs):
             # assert isinstance(i, Tensor), "not support "
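End to end, the machinery above protects this round trip (a sketch; pickle is the serialization layer these __getstate__/__setstate__ hooks plug into):

import pickle

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.traced_module import trace_module

class Net(M.Module):
    def forward(self, x):
        return F.relu(x + 1)

tm = trace_module(Net(), Tensor([1.0]))
buf = pickle.dumps(tm)    # __getstate__ records dump_info and per-expr versions
tm2 = pickle.loads(buf)   # __setstate__ replays the registered loaders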


imperative/python/megengine/traced_module/utils.py (+101, -1)

@@ -5,12 +5,17 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import collections
 import copy
+import inspect
 from collections.abc import MutableMapping, MutableSequence
-from typing import Dict, Iterable, List, Optional, Sequence
+from typing import Dict, Iterable, List, Optional, Sequence, Type

+from .. import get_logger
 from ..module import Module

+logger = get_logger(__name__)
+

 def replace_container_with_module_container(container):
     has_module = False
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
return has_module, module_container return has_module, module_container




def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
    # is_bounded = True when func is a method and the provided args don't include 'self'
    arg_specs = inspect.getfullargspec(func)
    arg_specs_args = arg_specs.args
    if is_bounded:
        arg_specs_args = arg_specs.args[1:]
    new_args = []
    new_kwargs = {}
    new_args.extend(args)
    repeated_arg_names = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
    if repeated_arg_names:
        raise TypeError(
            "{} got multiple values for argument {}".format(
                func.__qualname__, ", ".join(repeated_arg_names)
            )
        )
    if len(new_args) < len(arg_specs.args):
        for ind in range(len(new_args), len(arg_specs_args)):
            arg_name = arg_specs_args[ind]
            if arg_name in kwargs:
                new_args.append(kwargs[arg_name])
            else:
                index = ind - len(arg_specs_args) + len(arg_specs.defaults)
                assert index < len(arg_specs.defaults) and index >= 0
                new_args.append(arg_specs.defaults[index])

    for kwarg_name in arg_specs.kwonlyargs:
        if kwarg_name in kwargs:
            new_kwargs[kwarg_name] = kwargs[kwarg_name]
        else:
            assert kwarg_name in arg_specs.kwonlydefaults
            new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name]
    for k, v in kwargs.items():
        if k not in arg_specs.args and k not in arg_specs.kwonlyargs:
            if arg_specs.varkw is None:
                raise TypeError(
                    "{} got an unexpected keyword argument {}".format(
                        func.__qualname__, k
                    )
                )
            new_kwargs[k] = v
    return tuple(new_args), new_kwargs
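A quick worked example of the normalization this helper performs (``f`` is invented here; compare the dedicated unit test added near the bottom of this diff): positional parameters and their defaults are materialized into ``args``, while keyword-only values land in ``kwargs``.

    def f(a, b=2, *, c, d=4):
        pass

    # 'b' is filled from its default; keyword-only 'c' and 'd' stay as kwargs
    new_args, new_kwargs = _convert_kwargs_to_args(f, (1,), {"c": 3})
    assert new_args == (1, 2)
    assert new_kwargs == {"c": 3, "d": 4}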


def _check_obj_attr(obj):
    # check that all attributes of an object are serializable by traced module
    from .pytree import tree_flatten
    from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef
    from .expr import Expr
    from .traced_module import TracedModule, InternalGraph, NameSpace

    def _check_leaf_type(leaf):
        leaf_type = leaf if isinstance(leaf, type) else type(leaf)
        traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
        return (
            issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types))
            or leaf_type in SUPPORTED_LEAF_TYPE
        )

    for _, v in obj.items():
        leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
        for leaf in leafs:
            assert _check_leaf_type(
                leaf
            ), "Type {} is not supported by traced module".format(
                leaf if isinstance(leaf, type) else type(leaf)
            )
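To make the contract concrete, a hypothetical call (the dict stands in for an object's ``__dict__``, and this assumes plain scalars and strings are registered in ``SUPPORTED_LEAF_TYPE``): serializable leaves pass silently, anything else trips the assertion.

    _check_obj_attr({"scale": 1.0, "name": "conv0"})    # passes: float and str are supported leaves
    # _check_obj_attr({"handle": open("/tmp/x", "w")})  # would raise AssertionError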


def _check_builtin_module_attr(mod):
    from .pytree import _is_leaf as _check_leaf_type
    from .pytree import tree_flatten

    # check that all attributes of a builtin module are serializable
    is_non_serializable_module = lambda m: isinstance(
        m, Module
    ) and not _check_builtin_module_attr(m)
    for k, v in mod.__dict__.items():
        if k == "_m_dump_modulestate":
            continue
        if is_non_serializable_module(v):
            return False
        elif not isinstance(v, Module):
            leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
            for leaf in leafs:
                if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
                    logger.warning(
                        "Type {} is not supported by traced module".format(
                            leaf if isinstance(leaf, type) else type(leaf)
                        )
                    )
                    return False
    return True


class _ModuleList(Module, MutableSequence):
    r"""A List-like container.




+ 0
- 23
imperative/python/test/unit/core/test_serialization.py View File

@@ -15,7 +15,6 @@ import numpy as np
import megengine as mge
from megengine import Parameter, Tensor
from megengine.core.ops import builtin
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state




def test_tensor_serialization():
@@ -88,25 +87,3 @@ def test_compatibility():


test_old_tensor("tensor_v1_1.mge") test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge") test_old_tensor("tensor_v1_2.mge")


def test_opdef_serialization():
    with TemporaryFile() as f:
        x = builtin.Elemwise(mode="Add")
        pickle.dump(get_opdef_state(x), f)
        f.seek(0)
        load_x = load_opdef_from_state(pickle.load(f))
        assert x == load_x

    with TemporaryFile() as f:
        x = builtin.Convolution(stride_h=9, compute_mode="float32")
        x.strategy = (
            builtin.Convolution.Strategy.PROFILE
            | builtin.Convolution.Strategy.HEURISTIC
            | builtin.Convolution.Strategy.REPRODUCIBLE
        )
        pickle.dump(get_opdef_state(x), f)
        f.seek(0)
        load_x = load_opdef_from_state(pickle.load(f))
        assert x.strategy == load_x.strategy
        assert x == load_x

+ 3
- 3
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -85,12 +85,12 @@ class NewModule(M.Module):
        return x




def _check_expr_users(traced_module):
def _check_expr_users(flattened_module):
    node_user = defaultdict(list)
    for expr in traced_module.graph._exprs:
    for expr in flattened_module.graph._exprs:
        for node in expr.inputs:
            node_user[node].append(expr)
    for node in traced_module.graph.nodes():
    for node in flattened_module.graph.nodes():
        node.users.sort(key=lambda m: m._id)
        node_user[node].sort(key=lambda m: m._id)
        assert node.users == node_user[node]


+ 24
- 2
imperative/python/test/unit/traced_module/test_qat_module.py View File

@@ -8,6 +8,7 @@ import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as QM
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
    return getattr(self, name)




class MyConvBnRelu2d(M.ConvBnRelu2d):
    pass


class MyQATConvBnRelu2d(QM.ConvBnRelu2d):
    pass


class Myblcok(M.Module):
    def __init__(self,):
        super().__init__()
        self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.add = M.Elemwise("FUSE_ADD_RELU")
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):




def build_observered_net(net: M.Module, observer_cls):
    qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls))
    qat_net = Q.quantize_qat(
        net,
        qconfig=get_observer_config(observer_cls),
        mapping={MyConvBnRelu2d: MyQATConvBnRelu2d},
    )
    Q.enable_observer(qat_net)
    inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
    qat_net(inp)
@@ -134,6 +147,15 @@ def test_trace_qat():
        check_qparams(weight_qparams, traced_weight_qparams)
        if act_qparams:
            check_qparams(act_qparams, traced_act_qparams)
        flatten_traced_net = traced_net.flatten()
        conv0_node = flatten_traced_net.graph.get_node_by_name(
            "MyModule_block0_conv0"
        ).as_unique()
        conv0_out_node = flatten_traced_net.graph.get_node_by_name(
            "MyModule_block0_conv0_out"
        ).as_unique()
        assert isinstance(conv0_node.owner, TracedModule)
        assert conv0_out_node.expr.inputs[0] is conv0_node


    _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver))
    _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver))


+ 315
- 1
imperative/python/test/unit/traced_module/test_serialization.py View File

@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from collections import defaultdict
from tempfile import TemporaryFile


import numpy as np


import megengine.functional as F
import megengine.module as M
import megengine.traced_module.serialization as S
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise
from megengine.module import Module
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallMethod, Constant
from megengine.traced_module.node import TensorNode
from megengine.traced_module.serialization import (
    register_functional_loader,
    register_module_loader,
    register_opdef_loader,
    register_tensor_method_loader,
)
from megengine.traced_module.utils import _convert_kwargs_to_args


def _check_id(traced_module):
    _total_ids = traced_module.graph._total_ids
    node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
    assert len(set(node_ids)) == len(node_ids)
    assert max(node_ids) + 1 == _total_ids[0]

    expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
    assert len(set(expr_ids)) == len(expr_ids)
    assert max(expr_ids) + 1 == _total_ids[1]


def _check_name(flattened_module):
    node_names = [n._name for n in flattened_module.graph.nodes().as_list()]
    assert len(set(node_names)) == len(node_names)


def _check_expr_users(traced_module):
    node_user = defaultdict(list)
    for expr in traced_module.graph._exprs:
        for node in expr.inputs:
            node_user[node].append(expr)
        if isinstance(expr, CallMethod) and expr.graph:
            _check_expr_users(expr.inputs[0].owner)

    for node in traced_module.graph.nodes(False):
        node.users.sort(key=lambda m: m._id)
        node_user[node].sort(key=lambda m: m._id)
        assert node.users == node_user[node]



class MyBlock(Module):
@@ -48,5 +93,274 @@ def test_dump_and_load():
    traced_module = trace_module(module, x)
    np.testing.assert_array_equal(expect, traced_module(x))
    obj = pickle.dumps(traced_module)
    pickle.loads(obj)
    new_tm = pickle.loads(obj)
    _check_id(new_tm)
    _check_expr_users(new_tm)
    traced_module.graph._reset_ids()
    old_nodes = traced_module.graph.nodes().as_list()
    new_nodes = new_tm.graph.nodes().as_list()
    old_exprs = traced_module.graph.exprs().as_list()
    new_exprs = new_tm.graph.exprs().as_list()
    assert len(old_nodes) == len(new_nodes)
    for i, j in zip(old_nodes, new_nodes):
        assert i._name == j._name
        assert i._qualname == j._qualname
        assert i._id == j._id
    assert len(old_exprs) == len(new_exprs)
    for i, j in zip(old_exprs, new_exprs):
        assert i._id == j._id

    np.testing.assert_array_equal(expect, traced_module(x))


def test_opdef_loader():
    class MyModule1(Module):
        def forward(self, x, y):
            op = Elemwise("ADD")
            return apply(op, x, y)[0]

    m = MyModule1()
    x = Tensor(np.ones((20)))
    y = Tensor(np.ones((20)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.OPDEF_LOADER
    S.OPDEF_LOADER = {}

    @register_opdef_loader(Elemwise)
    def add_opdef_loader(expr):
        if expr.opdef_state["mode"] == "ADD":
            expr.opdef_state["mode"] = "MUL"
            node = expr.inputs[1]
            astype_expr = CallMethod(node, "astype")
            oup = TensorNode(
                astype_expr,
                shape=node.shape,
                dtype=expr.inputs[0].dtype,
                qparams=node.qparams,
            )
            astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
            astype_expr.return_val = (oup,)
            expr.inputs[1] = oup

    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_id(new_module)
    _check_expr_users(new_module)
    _check_name(new_module.flatten())
    assert (
        isinstance(new_module.graph._exprs[0], CallMethod)
        and new_module.graph._exprs[1].opdef.mode == "MUL"
        and len(new_module.graph._exprs) == 2
    )
    result = new_module(x, y)
    np.testing.assert_equal(result.numpy(), x.numpy())
    S.OPDEF_LOADER = orig_loader_dict


def test_functional_loader():
    class MyModule2(Module):
        def forward(self, x, y):
            return F.conv2d(x, y)

    m = MyModule2()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    y = Tensor(np.random.random((3, 3, 3, 3)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.FUNCTIONAL_LOADER
    S.FUNCTIONAL_LOADER = {}

    @register_functional_loader(("megengine.functional.nn", "conv2d"))
    def conv2df_loader(expr):
        # expr.func = ("megengine.functional.nn", "conv2d")
        kwargs = expr.kwargs
        orig_weight = expr.named_args["weight"]

        astype_expr = CallMethod(orig_weight, "astype")
        oup = TensorNode(
            astype_expr,
            shape=orig_weight.shape,
            dtype=orig_weight.dtype,
            qparams=orig_weight.qparams,
        )

        astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
        astype_expr.return_val = (oup,)

        expr.set_arg("weight", oup)

    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_expr_users(new_module)
    _check_id(new_module)
    result = new_module(x, y)
    gt = m(x, y)
    assert (
        isinstance(new_module.graph._exprs[0], CallMethod)
        and len(new_module.graph._exprs) == 2
    )
    np.testing.assert_equal(result.numpy(), gt.numpy())
    S.FUNCTIONAL_LOADER = orig_loader_dict


def test_tensor_method_loader():
    class MyModule3(Module):
        def forward(self, x):
            return x + 1

    m = MyModule3()
    x = Tensor(np.ones((20)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.TENSORMETHOD_LOADER
    S.TENSORMETHOD_LOADER = {}

    @register_tensor_method_loader("__add__")
    def add_loader(expr):
        args = list(expr.args)
        if not isinstance(args[1], TensorNode):
            args[1] = Tensor(args[1])
            node = Constant(args[1], "const").outputs[0]

            astype_expr = CallMethod(node, "astype")
            oup = TensorNode(
                astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
            )
            astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
            astype_expr.return_val = (oup,)

            add_expr = CallMethod(oup, "__add__")
            add_expr.set_args_kwargs(oup, oup)
            oup1 = TensorNode(
                add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
            )
            add_expr.return_val = oup1
            args[1] = oup1
            expr.set_args_kwargs(*args)

    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    _check_expr_users(new_module)
    _check_id(new_module)
    result = new_module(x)
    gt = m(x)
    assert (
        isinstance(new_module.graph._exprs[0], Constant)
        and len(new_module.graph._exprs) == 4
    )
    np.testing.assert_equal(result.numpy(), (x + 2).numpy())
    S.TENSORMETHOD_LOADER = orig_loader_dict


def test_module_loader():
    class MyModule4(Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    m = MyModule4()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.MODULE_LOADER
    S.MODULE_LOADER = {}

    @register_module_loader(("megengine.module.conv", "Conv2d"))
    def conv2dm_loader(expr):
        module = expr.inputs[0].owner
        args = list(expr.args)
        orig_inp = args[1]
        astype_expr = CallMethod(orig_inp, "astype")
        oup = TensorNode(
            astype_expr,
            shape=orig_inp.shape,
            dtype=orig_inp.dtype,
            qparams=orig_inp.qparams,
        )
        astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
        astype_expr.return_val = (oup,)
        args[1] = oup
        expr.set_args_kwargs(*args)

    obj = pickle.dumps(traced_module)
    new_module = pickle.loads(obj)
    result = new_module(x)
    gt = m(x)
    assert (
        isinstance(new_module.graph._exprs[1], CallMethod)
        and len(new_module.graph._exprs) == 3
    )
    np.testing.assert_equal(result.numpy(), gt.numpy())
    S.MODULE_LOADER = orig_loader_dict


def test_shared_module():
    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            self.a = M.Elemwise("ADD")
            self.b = self.a

        def forward(self, x, y):
            z = self.a(x, y)
            z = self.b(z, y)
            return z

    x = Tensor(1)
    y = Tensor(2)
    m = MyModule()
    tm = trace_module(m, x, y)
    obj = pickle.dumps(tm)
    load_tm = pickle.loads(obj)
    _check_expr_users(load_tm)
    _check_name(load_tm.flatten())
    _check_id(load_tm)
    assert load_tm.a is load_tm.b


def test_convert_kwargs_to_args():
    def func(a, b, c=4, *, d, e=3, f=4):
        pass

    args = (1,)
    kwargs = {"b": 1, "d": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
    assert new_args == (1, 1, 4)
    assert new_kwargs == {"d": 6, "e": 3, "f": 4}

    args = (1,)
    kwargs = {"d": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
    assert new_args == (1, 4)
    assert new_kwargs == {"d": 6, "e": 3, "f": 4}

    def func1(a, b, c, d, e, *, f):
        pass

    args = ()
    kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
    new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
    assert new_args == (1, 2, 3, 4, 5)
    assert new_kwargs == {"f": 6}


def test_opdef_serialization():
    with TemporaryFile() as f:
        x = builtin.Elemwise(mode="Add")
        pickle.dump(x, f)
        f.seek(0)
        load_x = pickle.load(f)
        assert x == load_x

    with TemporaryFile() as f:
        x = builtin.Convolution(stride_h=9, compute_mode="float32")
        x.strategy = (
            builtin.Convolution.Strategy.PROFILE
            | builtin.Convolution.Strategy.HEURISTIC
            | builtin.Convolution.Strategy.REPRODUCIBLE
        )
        pickle.dump(x, f)
        f.seek(0)
        load_x = pickle.load(f)
        assert x.strategy == load_x.strategy
        assert x == load_x
