@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
from . import compat | |||
from ._passes import optimize | |||
from .pytree import register_supported_type | |||
from .tm_config import disable_default_checker, enable_expr_checker | |||
from .traced_module import ( | |||
TracedModule, | |||
_register_all_builtin_module, | |||
@@ -29,4 +30,6 @@ __all__ = [ | |||
"wrap", | |||
"TracedModule", | |||
"optimize", | |||
"enable_expr_checker", | |||
"disable_default_checker", | |||
] |
@@ -0,0 +1,142 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import traceback | |||
from typing import Sequence | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.ops import ROIAlign, ROIPooling | |||
from ..core.ops.builtin import Copy | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..tensor import Tensor | |||
from .tm_config import _exclude_from_trace | |||
class TracedModuleChecker: | |||
def __init__(self, tracer): | |||
self._active_node2values = [] | |||
self.tracer = tracer | |||
self.node_without_tensor_info = {} | |||
def push_scope(self): | |||
self._active_node2values.append({}) | |||
def pop_scope(self): | |||
self._active_node2values.pop() | |||
def current_node2values(self): | |||
return self._active_node2values[-1] | |||
def reset_checker(self): | |||
self._active_node2values = [] | |||
def check_node_not_in_scope(self): | |||
if self.node_without_tensor_info: | |||
for node, info in self.node_without_tensor_info.items(): | |||
for expr in info[0]._exprs: | |||
if node in expr.inputs or node in expr.outputs: | |||
traceback.print_list(info[1]) | |||
raise ValueError( | |||
"node({}) not in the graph:\n{}".format(node, info[0]) | |||
) | |||
return True | |||
else: | |||
return False | |||
def check_net_outputs(self, tm_res, gt_res): | |||
if isinstance(tm_res, Tensor): | |||
np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy()) | |||
elif isinstance(tm_res, Sequence): | |||
for i, j in zip(tm_res, gt_res): | |||
np.testing.assert_allclose(i.numpy(), j.numpy()) | |||
else: | |||
for k in tm_res.__dict__.keys(): | |||
np.testing.assert_allclose( | |||
getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy() | |||
) | |||
def record_nodemixin(self, node, value): | |||
self.current_node2values()[node] = value | |||
def record_node2value(self, node, value): | |||
with _exclude_from_trace(): | |||
self.current_node2values()[node] = apply( | |||
Copy(comp_node=value.device), value | |||
)[0] | |||
if isscalar(value): | |||
setscalar(self.current_node2values()[node]) | |||
def check_apply_special_cases(self, opdef, num_outputs): | |||
indexs = list(range(num_outputs)) | |||
if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE: | |||
indexs.pop(-1) | |||
if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE: | |||
indexs.pop(-1) | |||
return indexs | |||
def check_expr_results(self, expr_outputs, gt_outputs, indexs=None): | |||
expr_outputs = ( | |||
(expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs | |||
) | |||
gt_outputs = ( | |||
(gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs | |||
) | |||
if indexs is not None: | |||
for i in indexs: | |||
np.testing.assert_allclose( | |||
expr_outputs[i].numpy(), gt_outputs[i].numpy() | |||
) | |||
else: | |||
np.testing.assert_allclose(expr_outputs, gt_outputs) | |||
def get_node2value(self, inputs, start_idx=0): | |||
inp_values = [] | |||
has_node_not_in_scope = False | |||
for i in range(start_idx, len(inputs)): | |||
try: | |||
inp_values.append(self.current_node2values()[inputs[i]]) | |||
except: | |||
has_node_not_in_scope = True | |||
self.node_without_tensor_info[inputs[i]] = [ | |||
self.tracer.current_scope(), | |||
traceback.extract_stack(), | |||
] | |||
return inp_values, has_node_not_in_scope | |||
def check_expr_interpret(self, expr, gt_outputs): | |||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||
if not has_node_not_in_scope: | |||
expr_res = expr.interpret(*ori_in) | |||
try: | |||
self.check_expr_results(expr_res, gt_outputs) | |||
except: | |||
raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||
def check_apply(self, expr, gt_outputs, opdef): | |||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||
if not has_node_not_in_scope: | |||
expr_res = expr.interpret(*ori_in) | |||
indexs = self.check_apply_special_cases(opdef, len(gt_outputs)) | |||
try: | |||
self.check_expr_results(expr_res, gt_outputs, indexs=indexs) | |||
except: | |||
raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||
def check_builtin_module(self, module, expr, gt_outputs): | |||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1) | |||
if not has_node_not_in_scope: | |||
ori_in.insert(0, module) | |||
expr_res = expr.interpret(*ori_in) | |||
try: | |||
self.check_expr_results(expr_res, gt_outputs) | |||
except: | |||
raise ValueError( | |||
"{}, Error occurred when checking expr: {}".format(expr) | |||
) |
@@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | |||
from .serialization import _ModuleState | |||
from .tm_config import _exclude_from_trace, _get_expr_checker | |||
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | |||
@@ -611,6 +612,8 @@ class Apply(Expr): | |||
inp_nodes = [NodeMixin.get(inputs[0])] | |||
for i in inputs[1:]: | |||
node = Constant.make(i) | |||
if _get_expr_checker(): | |||
active_module_tracer().checker.record_node2value(node, Tensor(i)) | |||
inp_nodes.append(node) | |||
apply_node = cls.make(opdef) | |||
for n in inp_nodes: | |||
@@ -624,11 +627,17 @@ class Apply(Expr): | |||
unset_module_tracing() | |||
outputs = apply(opdef, *inputs) | |||
outputs = list(map(Tensor, outputs)) | |||
set_module_tracing() | |||
apply_node.add_outputs(outputs) | |||
for n, v in zip(apply_node.outputs, outputs): | |||
NodeMixin.wrap_safe(v, n) | |||
if _get_expr_checker(): | |||
with _exclude_from_trace(): | |||
active_module_tracer().checker.check_apply(apply_node, outputs, opdef) | |||
return list(outputs) | |||
@@ -12,6 +12,7 @@ from .. import functional as F | |||
from ..core.tensor.array_method import ArrayMethodMixin | |||
from ..module import Module | |||
from ..module.qat import QATModule | |||
from .checker import TracedModuleChecker | |||
_active_module_tracer = None | |||
@@ -128,6 +129,7 @@ class module_tracer: | |||
def __init__(self, wrap_fn): | |||
self._active_scopes = [] | |||
self.checker = TracedModuleChecker(self) | |||
self.patcher = Patcher(wrap_fn) | |||
@classmethod | |||
@@ -142,9 +144,11 @@ class module_tracer: | |||
def push_scope(self, scope): | |||
self._active_scopes.append(scope) | |||
self.checker.push_scope() | |||
def pop_scope(self): | |||
self._active_scopes.pop() | |||
self.checker.pop_scope() | |||
def current_scope(self): | |||
if self._active_scopes: | |||
@@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..module import Module | |||
from ..quantization.utils import QParams | |||
from ..tensor import Tensor | |||
from .module_tracer import active_module_tracer | |||
from .tm_config import _get_expr_checker | |||
from .utils import _check_obj_attr | |||
logger = get_logger(__name__) | |||
@@ -343,6 +345,11 @@ class NodeMixin(abc.ABC): | |||
if isinstance(value, NodeMixin): | |||
value._record_wrapped_nodes(node) | |||
setattr(value, "_NodeMixin__node", node) | |||
if _get_expr_checker(): | |||
if isinstance(value, RawTensor): | |||
active_module_tracer().checker.record_node2value(node, value) | |||
if isinstance(value, NodeMixin): | |||
active_module_tracer().checker.record_nodemixin(node, value) | |||
else: | |||
assert callable(node) | |||
n = node() | |||
@@ -352,6 +359,11 @@ class NodeMixin(abc.ABC): | |||
if isinstance(value, NodeMixin): | |||
value._record_wrapped_nodes(n) | |||
setattr(value, "_NodeMixin__node", n) | |||
if _get_expr_checker(): | |||
if isinstance(value, RawTensor): | |||
active_module_tracer().checker.record_node2value(n, value) | |||
if isinstance(value, NodeMixin): | |||
active_module_tracer().checker.record_nodemixin(n, value) | |||
@classmethod | |||
def wrap_safe(cls, value, node): | |||
@@ -359,6 +371,11 @@ class NodeMixin(abc.ABC): | |||
if isinstance(value, RawTensor): | |||
cls._record_tensornode_property(node, value) | |||
setattr(value, "_NodeMixin__node", node) | |||
if _get_expr_checker(): | |||
if isinstance(value, RawTensor): | |||
active_module_tracer().checker.record_node2value(node, value) | |||
if isinstance(value, NodeMixin): | |||
active_module_tracer().checker.record_nodemixin(node, value) | |||
if isinstance(value, NodeMixin): | |||
value._record_wrapped_nodes(node) | |||
@@ -212,7 +212,11 @@ def tree_flatten( | |||
to reconstruct the pytree. | |||
""" | |||
if type(values) not in SUPPORTED_TYPE: | |||
assert is_leaf(values), values | |||
assert is_leaf( | |||
values | |||
), 'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format( | |||
values | |||
) | |||
node = LeafDef(leaf_type(values)) | |||
if is_const_leaf(values): | |||
node.const_val = values | |||
@@ -0,0 +1,55 @@ | |||
import contextlib | |||
from ..core._imperative_rt.core2 import ( | |||
is_tracing_module, | |||
set_module_tracing, | |||
unset_module_tracing, | |||
) | |||
_enable_expr_checker = False | |||
_enable_default_checker = True | |||
def _get_expr_checker(): | |||
return _enable_expr_checker | |||
def _get_default_checker(): | |||
return _enable_default_checker | |||
def enable_expr_checker(): | |||
r"""Call this function to check the result of each expr during tracing.""" | |||
global _enable_expr_checker | |||
_enable_expr_checker = True | |||
_enable_default_checker = False | |||
def disable_default_checker(): | |||
r"""Call this function to disable checking the final output of the model after tracing.""" | |||
global _enable_default_checker | |||
_enable_default_checker = False | |||
_enable_graph_surgery_mode = False | |||
def _graph_surgery_mode(): | |||
return _enable_graph_surgery_mode | |||
def _set_graph_surgery_mode(mode: bool): | |||
global _enable_graph_surgery_mode | |||
pre_mode = _enable_graph_surgery_mode | |||
_enable_graph_surgery_mode = mode | |||
return pre_mode | |||
@contextlib.contextmanager | |||
def _exclude_from_trace(): | |||
is_tracing = is_tracing_module() | |||
if is_tracing: | |||
unset_module_tracing() | |||
yield | |||
if is_tracing: | |||
set_module_tracing() |
@@ -36,11 +36,14 @@ from .. import get_logger | |||
from .. import module as M | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._imperative_rt.core2 import ( | |||
apply, | |||
is_tracing_module, | |||
set_module_tracing, | |||
unset_module_tracing, | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core.ops.builtin import Copy | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..module import Module | |||
from ..module import external as MExternal | |||
from ..module.qat import QATModule | |||
@@ -98,6 +101,13 @@ from .serialization import ( | |||
load_call_tensor_method_expr, | |||
load_functional, | |||
) | |||
from .tm_config import ( | |||
_exclude_from_trace, | |||
_get_default_checker, | |||
_get_expr_checker, | |||
_graph_surgery_mode, | |||
_set_graph_surgery_mode, | |||
) | |||
from .utils import ( | |||
_check_builtin_module_attr, | |||
_check_obj_attr, | |||
@@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool: | |||
def _is_leaf(node): | |||
assert isinstance(node, RawTensor), "doesn't support {} in return values".format( | |||
assert isinstance( | |||
node, RawTensor | |||
), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format( | |||
type(node) | |||
) | |||
return isinstance(node, RawTensor) | |||
_enable_graph_surgery_mode = False | |||
def _graph_surgery_mode(): | |||
return _enable_graph_surgery_mode | |||
def _set_graph_surgery_mode(mode: bool): | |||
global _enable_graph_surgery_mode | |||
pre_mode = _enable_graph_surgery_mode | |||
_enable_graph_surgery_mode = mode | |||
return pre_mode | |||
def _node_to_tensor(*args, **kwargs): | |||
tensors = [] | |||
nodes, tree_def = tree_flatten((args, kwargs)) | |||
@@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func): | |||
return orig_func(*args, **kwargs) | |||
if isinstance(args[1], RawTensor): | |||
node = NodeMixin.get(inputs[1]) | |||
inputs[1] = copy.copy(inputs[1]) | |||
is_scalar = isscalar(inputs[1]) | |||
inputs[1] = apply( | |||
Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | |||
)[0] | |||
if is_scalar: | |||
setscalar(inputs[1]) | |||
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | |||
# which will cause they have same _NodeMixin__node in tracing. | |||
NodeMixin.wrap_safe(inputs[1], node) | |||
@@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func): | |||
else: | |||
outputs = None | |||
call_node.add_outputs(outputs) | |||
if _get_expr_checker(): | |||
with _exclude_from_trace(): | |||
active_module_tracer().checker.check_expr_interpret( | |||
call_node, outputs | |||
) | |||
set_module_tracing() | |||
return rst | |||
return orig_func(*args, **kwargs) | |||
@@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin): | |||
unset_module_tracing() | |||
rst = self._mod(*args, **kwargs) | |||
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | |||
if _get_expr_checker(): | |||
with _exclude_from_trace(): | |||
tmp = self.build() | |||
active_module_tracer().checker.check_builtin_module( | |||
tmp, callnode, outputs | |||
) | |||
set_module_tracing() | |||
if self._is_builtin: | |||
self._body = None | |||
@@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin): | |||
if not isinstance(mod_attr, (List, Dict, QATModule)): | |||
assert mod_attr is wrapped._mod | |||
else: | |||
assert mod_attr is wrapped | |||
assert ( | |||
mod_attr is wrapped | |||
), "TracedModule do not support modify attributes, please check your code." | |||
if isinstance(wrapped, (NodeMixin, RawTensor)): | |||
NodeMixin.wrap( | |||
@@ -2469,11 +2487,23 @@ def trace_module( | |||
qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), | |||
), | |||
) | |||
builder(*args, **kwargs) | |||
rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs)) | |||
active_module_tracer().pop_scope() | |||
traced_mod = builder.build() | |||
traced_mod.argspec = forward_argspec | |||
traced_mod.graph._reset_ids() | |||
has_expr_not_check = False | |||
if _get_expr_checker(): | |||
has_expr_not_check = ( | |||
active_module_tracer().checker.check_node_not_in_scope() | |||
) | |||
if _get_default_checker() or has_expr_not_check: | |||
with _exclude_from_trace(): | |||
tm_res = traced_mod(*args, **kwargs) | |||
tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf) | |||
rst, _ = tree_flatten(rst, is_leaf=_is_leaf) | |||
active_module_tracer().checker.check_net_outputs(tm_res, rst) | |||
return traced_mod | |||
finally: | |||
set_symbolic_shape(use_sym_shape) | |||
@@ -5,16 +5,15 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import copy | |||
import inspect | |||
from collections.abc import MutableMapping, MutableSequence | |||
from inspect import FullArgSpec | |||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||
from .. import get_logger | |||
from ..module import Module | |||
from ..tensor import Parameter, Tensor | |||
from ..tensor import Tensor | |||
logger = get_logger(__name__) | |||
@@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||
) | |||
Q.enable_observer(qat_net) | |||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
qat_net.eval() | |||
qat_net(inp) | |||
Q.disable_observer(qat_net) | |||
return qat_net | |||
@@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||
def build_fakequanted_net(net: QATModule, fakequant_cls): | |||
qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) | |||
qat_net.eval() | |||
return qat_net | |||
@@ -162,6 +164,7 @@ def test_load_param(): | |||
def _check_module(build_func: Callable): | |||
net = build_func() | |||
net.eval() | |||
buffer = io.BytesIO() | |||
mge.save(net.state_dict(), buffer) | |||
buffer.seek(0) | |||
@@ -185,6 +188,7 @@ def test_load_param(): | |||
def test_qualname(): | |||
def _check_qualname(net): | |||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
net.eval() | |||
traced_net = trace_module(net, inp) | |||
base_qualname = traced_net.graph.qualname | |||
for node in traced_net.graph.nodes(): | |||