@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||||
from . import compat | from . import compat | ||||
from ._passes import optimize | from ._passes import optimize | ||||
from .pytree import register_supported_type | from .pytree import register_supported_type | ||||
from .tm_config import disable_default_checker, enable_expr_checker | |||||
from .traced_module import ( | from .traced_module import ( | ||||
TracedModule, | TracedModule, | ||||
_register_all_builtin_module, | _register_all_builtin_module, | ||||
@@ -29,4 +30,6 @@ __all__ = [ | |||||
"wrap", | "wrap", | ||||
"TracedModule", | "TracedModule", | ||||
"optimize", | "optimize", | ||||
"enable_expr_checker", | |||||
"disable_default_checker", | |||||
] | ] |
@@ -0,0 +1,142 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import traceback | |||||
from typing import Sequence | |||||
import numpy as np | |||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core._imperative_rt.ops import ROIAlign, ROIPooling | |||||
from ..core.ops.builtin import Copy | |||||
from ..core.tensor.utils import isscalar, setscalar | |||||
from ..tensor import Tensor | |||||
from .tm_config import _exclude_from_trace | |||||
class TracedModuleChecker: | |||||
def __init__(self, tracer): | |||||
self._active_node2values = [] | |||||
self.tracer = tracer | |||||
self.node_without_tensor_info = {} | |||||
def push_scope(self): | |||||
self._active_node2values.append({}) | |||||
def pop_scope(self): | |||||
self._active_node2values.pop() | |||||
def current_node2values(self): | |||||
return self._active_node2values[-1] | |||||
def reset_checker(self): | |||||
self._active_node2values = [] | |||||
def check_node_not_in_scope(self): | |||||
if self.node_without_tensor_info: | |||||
for node, info in self.node_without_tensor_info.items(): | |||||
for expr in info[0]._exprs: | |||||
if node in expr.inputs or node in expr.outputs: | |||||
traceback.print_list(info[1]) | |||||
raise ValueError( | |||||
"node({}) not in the graph:\n{}".format(node, info[0]) | |||||
) | |||||
return True | |||||
else: | |||||
return False | |||||
def check_net_outputs(self, tm_res, gt_res): | |||||
if isinstance(tm_res, Tensor): | |||||
np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy()) | |||||
elif isinstance(tm_res, Sequence): | |||||
for i, j in zip(tm_res, gt_res): | |||||
np.testing.assert_allclose(i.numpy(), j.numpy()) | |||||
else: | |||||
for k in tm_res.__dict__.keys(): | |||||
np.testing.assert_allclose( | |||||
getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy() | |||||
) | |||||
def record_nodemixin(self, node, value): | |||||
self.current_node2values()[node] = value | |||||
def record_node2value(self, node, value): | |||||
with _exclude_from_trace(): | |||||
self.current_node2values()[node] = apply( | |||||
Copy(comp_node=value.device), value | |||||
)[0] | |||||
if isscalar(value): | |||||
setscalar(self.current_node2values()[node]) | |||||
def check_apply_special_cases(self, opdef, num_outputs): | |||||
indexs = list(range(num_outputs)) | |||||
if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE: | |||||
indexs.pop(-1) | |||||
if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE: | |||||
indexs.pop(-1) | |||||
return indexs | |||||
def check_expr_results(self, expr_outputs, gt_outputs, indexs=None): | |||||
expr_outputs = ( | |||||
(expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs | |||||
) | |||||
gt_outputs = ( | |||||
(gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs | |||||
) | |||||
if indexs is not None: | |||||
for i in indexs: | |||||
np.testing.assert_allclose( | |||||
expr_outputs[i].numpy(), gt_outputs[i].numpy() | |||||
) | |||||
else: | |||||
np.testing.assert_allclose(expr_outputs, gt_outputs) | |||||
def get_node2value(self, inputs, start_idx=0): | |||||
inp_values = [] | |||||
has_node_not_in_scope = False | |||||
for i in range(start_idx, len(inputs)): | |||||
try: | |||||
inp_values.append(self.current_node2values()[inputs[i]]) | |||||
except: | |||||
has_node_not_in_scope = True | |||||
self.node_without_tensor_info[inputs[i]] = [ | |||||
self.tracer.current_scope(), | |||||
traceback.extract_stack(), | |||||
] | |||||
return inp_values, has_node_not_in_scope | |||||
def check_expr_interpret(self, expr, gt_outputs): | |||||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||||
if not has_node_not_in_scope: | |||||
expr_res = expr.interpret(*ori_in) | |||||
try: | |||||
self.check_expr_results(expr_res, gt_outputs) | |||||
except: | |||||
raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||||
def check_apply(self, expr, gt_outputs, opdef): | |||||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||||
if not has_node_not_in_scope: | |||||
expr_res = expr.interpret(*ori_in) | |||||
indexs = self.check_apply_special_cases(opdef, len(gt_outputs)) | |||||
try: | |||||
self.check_expr_results(expr_res, gt_outputs, indexs=indexs) | |||||
except: | |||||
raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||||
def check_builtin_module(self, module, expr, gt_outputs): | |||||
ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1) | |||||
if not has_node_not_in_scope: | |||||
ori_in.insert(0, module) | |||||
expr_res = expr.interpret(*ori_in) | |||||
try: | |||||
self.check_expr_results(expr_res, gt_outputs) | |||||
except: | |||||
raise ValueError( | |||||
"{}, Error occurred when checking expr: {}".format(expr) | |||||
) |
@@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer | |||||
from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | ||||
from .serialization import _ModuleState | from .serialization import _ModuleState | ||||
from .tm_config import _exclude_from_trace, _get_expr_checker | |||||
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | ||||
@@ -611,6 +612,8 @@ class Apply(Expr): | |||||
inp_nodes = [NodeMixin.get(inputs[0])] | inp_nodes = [NodeMixin.get(inputs[0])] | ||||
for i in inputs[1:]: | for i in inputs[1:]: | ||||
node = Constant.make(i) | node = Constant.make(i) | ||||
if _get_expr_checker(): | |||||
active_module_tracer().checker.record_node2value(node, Tensor(i)) | |||||
inp_nodes.append(node) | inp_nodes.append(node) | ||||
apply_node = cls.make(opdef) | apply_node = cls.make(opdef) | ||||
for n in inp_nodes: | for n in inp_nodes: | ||||
@@ -624,11 +627,17 @@ class Apply(Expr): | |||||
unset_module_tracing() | unset_module_tracing() | ||||
outputs = apply(opdef, *inputs) | outputs = apply(opdef, *inputs) | ||||
outputs = list(map(Tensor, outputs)) | |||||
set_module_tracing() | set_module_tracing() | ||||
apply_node.add_outputs(outputs) | apply_node.add_outputs(outputs) | ||||
for n, v in zip(apply_node.outputs, outputs): | for n, v in zip(apply_node.outputs, outputs): | ||||
NodeMixin.wrap_safe(v, n) | NodeMixin.wrap_safe(v, n) | ||||
if _get_expr_checker(): | |||||
with _exclude_from_trace(): | |||||
active_module_tracer().checker.check_apply(apply_node, outputs, opdef) | |||||
return list(outputs) | return list(outputs) | ||||
@@ -12,6 +12,7 @@ from .. import functional as F | |||||
from ..core.tensor.array_method import ArrayMethodMixin | from ..core.tensor.array_method import ArrayMethodMixin | ||||
from ..module import Module | from ..module import Module | ||||
from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
from .checker import TracedModuleChecker | |||||
_active_module_tracer = None | _active_module_tracer = None | ||||
@@ -128,6 +129,7 @@ class module_tracer: | |||||
def __init__(self, wrap_fn): | def __init__(self, wrap_fn): | ||||
self._active_scopes = [] | self._active_scopes = [] | ||||
self.checker = TracedModuleChecker(self) | |||||
self.patcher = Patcher(wrap_fn) | self.patcher = Patcher(wrap_fn) | ||||
@classmethod | @classmethod | ||||
@@ -142,9 +144,11 @@ class module_tracer: | |||||
def push_scope(self, scope): | def push_scope(self, scope): | ||||
self._active_scopes.append(scope) | self._active_scopes.append(scope) | ||||
self.checker.push_scope() | |||||
def pop_scope(self): | def pop_scope(self): | ||||
self._active_scopes.pop() | self._active_scopes.pop() | ||||
self.checker.pop_scope() | |||||
def current_scope(self): | def current_scope(self): | ||||
if self._active_scopes: | if self._active_scopes: | ||||
@@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor | |||||
from ..module import Module | from ..module import Module | ||||
from ..quantization.utils import QParams | from ..quantization.utils import QParams | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .module_tracer import active_module_tracer | |||||
from .tm_config import _get_expr_checker | |||||
from .utils import _check_obj_attr | from .utils import _check_obj_attr | ||||
logger = get_logger(__name__) | logger = get_logger(__name__) | ||||
@@ -343,6 +345,11 @@ class NodeMixin(abc.ABC): | |||||
if isinstance(value, NodeMixin): | if isinstance(value, NodeMixin): | ||||
value._record_wrapped_nodes(node) | value._record_wrapped_nodes(node) | ||||
setattr(value, "_NodeMixin__node", node) | setattr(value, "_NodeMixin__node", node) | ||||
if _get_expr_checker(): | |||||
if isinstance(value, RawTensor): | |||||
active_module_tracer().checker.record_node2value(node, value) | |||||
if isinstance(value, NodeMixin): | |||||
active_module_tracer().checker.record_nodemixin(node, value) | |||||
else: | else: | ||||
assert callable(node) | assert callable(node) | ||||
n = node() | n = node() | ||||
@@ -352,6 +359,11 @@ class NodeMixin(abc.ABC): | |||||
if isinstance(value, NodeMixin): | if isinstance(value, NodeMixin): | ||||
value._record_wrapped_nodes(n) | value._record_wrapped_nodes(n) | ||||
setattr(value, "_NodeMixin__node", n) | setattr(value, "_NodeMixin__node", n) | ||||
if _get_expr_checker(): | |||||
if isinstance(value, RawTensor): | |||||
active_module_tracer().checker.record_node2value(n, value) | |||||
if isinstance(value, NodeMixin): | |||||
active_module_tracer().checker.record_nodemixin(n, value) | |||||
@classmethod | @classmethod | ||||
def wrap_safe(cls, value, node): | def wrap_safe(cls, value, node): | ||||
@@ -359,6 +371,11 @@ class NodeMixin(abc.ABC): | |||||
if isinstance(value, RawTensor): | if isinstance(value, RawTensor): | ||||
cls._record_tensornode_property(node, value) | cls._record_tensornode_property(node, value) | ||||
setattr(value, "_NodeMixin__node", node) | setattr(value, "_NodeMixin__node", node) | ||||
if _get_expr_checker(): | |||||
if isinstance(value, RawTensor): | |||||
active_module_tracer().checker.record_node2value(node, value) | |||||
if isinstance(value, NodeMixin): | |||||
active_module_tracer().checker.record_nodemixin(node, value) | |||||
if isinstance(value, NodeMixin): | if isinstance(value, NodeMixin): | ||||
value._record_wrapped_nodes(node) | value._record_wrapped_nodes(node) | ||||
@@ -212,7 +212,11 @@ def tree_flatten( | |||||
to reconstruct the pytree. | to reconstruct the pytree. | ||||
""" | """ | ||||
if type(values) not in SUPPORTED_TYPE: | if type(values) not in SUPPORTED_TYPE: | ||||
assert is_leaf(values), values | |||||
assert is_leaf( | |||||
values | |||||
), 'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format( | |||||
values | |||||
) | |||||
node = LeafDef(leaf_type(values)) | node = LeafDef(leaf_type(values)) | ||||
if is_const_leaf(values): | if is_const_leaf(values): | ||||
node.const_val = values | node.const_val = values | ||||
@@ -0,0 +1,55 @@ | |||||
import contextlib | |||||
from ..core._imperative_rt.core2 import ( | |||||
is_tracing_module, | |||||
set_module_tracing, | |||||
unset_module_tracing, | |||||
) | |||||
_enable_expr_checker = False | |||||
_enable_default_checker = True | |||||
def _get_expr_checker(): | |||||
return _enable_expr_checker | |||||
def _get_default_checker(): | |||||
return _enable_default_checker | |||||
def enable_expr_checker(): | |||||
r"""Call this function to check the result of each expr during tracing.""" | |||||
global _enable_expr_checker | |||||
_enable_expr_checker = True | |||||
_enable_default_checker = False | |||||
def disable_default_checker(): | |||||
r"""Call this function to disable checking the final output of the model after tracing.""" | |||||
global _enable_default_checker | |||||
_enable_default_checker = False | |||||
_enable_graph_surgery_mode = False | |||||
def _graph_surgery_mode(): | |||||
return _enable_graph_surgery_mode | |||||
def _set_graph_surgery_mode(mode: bool): | |||||
global _enable_graph_surgery_mode | |||||
pre_mode = _enable_graph_surgery_mode | |||||
_enable_graph_surgery_mode = mode | |||||
return pre_mode | |||||
@contextlib.contextmanager | |||||
def _exclude_from_trace(): | |||||
is_tracing = is_tracing_module() | |||||
if is_tracing: | |||||
unset_module_tracing() | |||||
yield | |||||
if is_tracing: | |||||
set_module_tracing() |
@@ -36,11 +36,14 @@ from .. import get_logger | |||||
from .. import module as M | from .. import module as M | ||||
from ..core._imperative_rt.core2 import Tensor as RawTensor | from ..core._imperative_rt.core2 import Tensor as RawTensor | ||||
from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
apply, | |||||
is_tracing_module, | is_tracing_module, | ||||
set_module_tracing, | set_module_tracing, | ||||
unset_module_tracing, | unset_module_tracing, | ||||
) | ) | ||||
from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
from ..core.ops.builtin import Copy | |||||
from ..core.tensor.utils import isscalar, setscalar | |||||
from ..module import Module | from ..module import Module | ||||
from ..module import external as MExternal | from ..module import external as MExternal | ||||
from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
@@ -98,6 +101,13 @@ from .serialization import ( | |||||
load_call_tensor_method_expr, | load_call_tensor_method_expr, | ||||
load_functional, | load_functional, | ||||
) | ) | ||||
from .tm_config import ( | |||||
_exclude_from_trace, | |||||
_get_default_checker, | |||||
_get_expr_checker, | |||||
_graph_surgery_mode, | |||||
_set_graph_surgery_mode, | |||||
) | |||||
from .utils import ( | from .utils import ( | ||||
_check_builtin_module_attr, | _check_builtin_module_attr, | ||||
_check_obj_attr, | _check_obj_attr, | ||||
@@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool: | |||||
def _is_leaf(node): | def _is_leaf(node): | ||||
assert isinstance(node, RawTensor), "doesn't support {} in return values".format( | |||||
assert isinstance( | |||||
node, RawTensor | |||||
), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format( | |||||
type(node) | type(node) | ||||
) | ) | ||||
return isinstance(node, RawTensor) | return isinstance(node, RawTensor) | ||||
_enable_graph_surgery_mode = False | |||||
def _graph_surgery_mode(): | |||||
return _enable_graph_surgery_mode | |||||
def _set_graph_surgery_mode(mode: bool): | |||||
global _enable_graph_surgery_mode | |||||
pre_mode = _enable_graph_surgery_mode | |||||
_enable_graph_surgery_mode = mode | |||||
return pre_mode | |||||
def _node_to_tensor(*args, **kwargs): | def _node_to_tensor(*args, **kwargs): | ||||
tensors = [] | tensors = [] | ||||
nodes, tree_def = tree_flatten((args, kwargs)) | nodes, tree_def = tree_flatten((args, kwargs)) | ||||
@@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func): | |||||
return orig_func(*args, **kwargs) | return orig_func(*args, **kwargs) | ||||
if isinstance(args[1], RawTensor): | if isinstance(args[1], RawTensor): | ||||
node = NodeMixin.get(inputs[1]) | node = NodeMixin.get(inputs[1]) | ||||
inputs[1] = copy.copy(inputs[1]) | |||||
is_scalar = isscalar(inputs[1]) | |||||
inputs[1] = apply( | |||||
Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | |||||
)[0] | |||||
if is_scalar: | |||||
setscalar(inputs[1]) | |||||
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | ||||
# which will cause they have same _NodeMixin__node in tracing. | # which will cause they have same _NodeMixin__node in tracing. | ||||
NodeMixin.wrap_safe(inputs[1], node) | NodeMixin.wrap_safe(inputs[1], node) | ||||
@@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func): | |||||
else: | else: | ||||
outputs = None | outputs = None | ||||
call_node.add_outputs(outputs) | call_node.add_outputs(outputs) | ||||
if _get_expr_checker(): | |||||
with _exclude_from_trace(): | |||||
active_module_tracer().checker.check_expr_interpret( | |||||
call_node, outputs | |||||
) | |||||
set_module_tracing() | set_module_tracing() | ||||
return rst | return rst | ||||
return orig_func(*args, **kwargs) | return orig_func(*args, **kwargs) | ||||
@@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin): | |||||
unset_module_tracing() | unset_module_tracing() | ||||
rst = self._mod(*args, **kwargs) | rst = self._mod(*args, **kwargs) | ||||
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | ||||
if _get_expr_checker(): | |||||
with _exclude_from_trace(): | |||||
tmp = self.build() | |||||
active_module_tracer().checker.check_builtin_module( | |||||
tmp, callnode, outputs | |||||
) | |||||
set_module_tracing() | set_module_tracing() | ||||
if self._is_builtin: | if self._is_builtin: | ||||
self._body = None | self._body = None | ||||
@@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin): | |||||
if not isinstance(mod_attr, (List, Dict, QATModule)): | if not isinstance(mod_attr, (List, Dict, QATModule)): | ||||
assert mod_attr is wrapped._mod | assert mod_attr is wrapped._mod | ||||
else: | else: | ||||
assert mod_attr is wrapped | |||||
assert ( | |||||
mod_attr is wrapped | |||||
), "TracedModule do not support modify attributes, please check your code." | |||||
if isinstance(wrapped, (NodeMixin, RawTensor)): | if isinstance(wrapped, (NodeMixin, RawTensor)): | ||||
NodeMixin.wrap( | NodeMixin.wrap( | ||||
@@ -2469,11 +2487,23 @@ def trace_module( | |||||
qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), | qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), | ||||
), | ), | ||||
) | ) | ||||
builder(*args, **kwargs) | |||||
rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs)) | |||||
active_module_tracer().pop_scope() | active_module_tracer().pop_scope() | ||||
traced_mod = builder.build() | traced_mod = builder.build() | ||||
traced_mod.argspec = forward_argspec | traced_mod.argspec = forward_argspec | ||||
traced_mod.graph._reset_ids() | traced_mod.graph._reset_ids() | ||||
has_expr_not_check = False | |||||
if _get_expr_checker(): | |||||
has_expr_not_check = ( | |||||
active_module_tracer().checker.check_node_not_in_scope() | |||||
) | |||||
if _get_default_checker() or has_expr_not_check: | |||||
with _exclude_from_trace(): | |||||
tm_res = traced_mod(*args, **kwargs) | |||||
tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf) | |||||
rst, _ = tree_flatten(rst, is_leaf=_is_leaf) | |||||
active_module_tracer().checker.check_net_outputs(tm_res, rst) | |||||
return traced_mod | return traced_mod | ||||
finally: | finally: | ||||
set_symbolic_shape(use_sym_shape) | set_symbolic_shape(use_sym_shape) | ||||
@@ -5,16 +5,15 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | |||||
import copy | import copy | ||||
import inspect | import inspect | ||||
from collections.abc import MutableMapping, MutableSequence | from collections.abc import MutableMapping, MutableSequence | ||||
from inspect import FullArgSpec | from inspect import FullArgSpec | ||||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||||
from .. import get_logger | from .. import get_logger | ||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Parameter, Tensor | |||||
from ..tensor import Tensor | |||||
logger = get_logger(__name__) | logger = get_logger(__name__) | ||||
@@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||||
) | ) | ||||
Q.enable_observer(qat_net) | Q.enable_observer(qat_net) | ||||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | ||||
qat_net.eval() | |||||
qat_net(inp) | qat_net(inp) | ||||
Q.disable_observer(qat_net) | Q.disable_observer(qat_net) | ||||
return qat_net | return qat_net | ||||
@@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||||
def build_fakequanted_net(net: QATModule, fakequant_cls): | def build_fakequanted_net(net: QATModule, fakequant_cls): | ||||
qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) | qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) | ||||
qat_net.eval() | |||||
return qat_net | return qat_net | ||||
@@ -162,6 +164,7 @@ def test_load_param(): | |||||
def _check_module(build_func: Callable): | def _check_module(build_func: Callable): | ||||
net = build_func() | net = build_func() | ||||
net.eval() | |||||
buffer = io.BytesIO() | buffer = io.BytesIO() | ||||
mge.save(net.state_dict(), buffer) | mge.save(net.state_dict(), buffer) | ||||
buffer.seek(0) | buffer.seek(0) | ||||
@@ -185,6 +188,7 @@ def test_load_param(): | |||||
def test_qualname(): | def test_qualname(): | ||||
def _check_qualname(net): | def _check_qualname(net): | ||||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | ||||
net.eval() | |||||
traced_net = trace_module(net, inp) | traced_net = trace_module(net, inp) | ||||
base_qualname = traced_net.graph.qualname | base_qualname = traced_net.graph.qualname | ||||
for node in traced_net.graph.nodes(): | for node in traced_net.graph.nodes(): | ||||