Browse Source

feat(imperative): add TracedModule checker

GitOrigin-RevId: 12de7b278e
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
a0b3a3c0b3
10 changed files with 289 additions and 22 deletions
  1. +3
    -0
      imperative/python/megengine/traced_module/__init__.py
  2. +142
    -0
      imperative/python/megengine/traced_module/checker.py
  3. +9
    -0
      imperative/python/megengine/traced_module/expr.py
  4. +4
    -0
      imperative/python/megengine/traced_module/module_tracer.py
  5. +17
    -0
      imperative/python/megengine/traced_module/node.py
  6. +5
    -1
      imperative/python/megengine/traced_module/pytree.py
  7. +55
    -0
      imperative/python/megengine/traced_module/tm_config.py
  8. +48
    -18
      imperative/python/megengine/traced_module/traced_module.py
  9. +2
    -3
      imperative/python/megengine/traced_module/utils.py
  10. +4
    -0
      imperative/python/test/unit/traced_module/test_qat_module.py

+ 3
- 0
imperative/python/megengine/traced_module/__init__.py View File

@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .pytree import register_supported_type
from .tm_config import disable_default_checker, enable_expr_checker
from .traced_module import (
TracedModule,
_register_all_builtin_module,
@@ -29,4 +30,6 @@ __all__ = [
"wrap",
"TracedModule",
"optimize",
"enable_expr_checker",
"disable_default_checker",
]

+ 142
- 0
imperative/python/megengine/traced_module/checker.py View File

@@ -0,0 +1,142 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import traceback
from typing import Sequence

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..tensor import Tensor
from .tm_config import _exclude_from_trace


class TracedModuleChecker:
    r"""Checks that a ``TracedModule``'s exprs reproduce the results of the
    original module during tracing.

    The checker keeps one ``node -> value`` mapping per active tracing scope
    and compares each expr's interpreted outputs against the ground-truth
    values recorded while the original module ran.
    """

    def __init__(self, tracer):
        # Stack of per-scope dicts mapping Node -> recorded value.
        self._active_node2values = []
        self.tracer = tracer

        # Nodes whose values could not be found in the active scope:
        # node -> [scope graph, traceback stack at the time of the miss].
        self.node_without_tensor_info = {}

    def push_scope(self):
        """Open a fresh node->value mapping for a new tracing scope."""
        self._active_node2values.append({})

    def pop_scope(self):
        """Drop the mapping of the scope being exited."""
        self._active_node2values.pop()

    def current_node2values(self):
        """Return the node->value mapping of the innermost active scope."""
        return self._active_node2values[-1]

    def reset_checker(self):
        """Clear all recorded scopes."""
        self._active_node2values = []

    def check_node_not_in_scope(self):
        """Report nodes that were referenced but never recorded in any scope.

        Returns:
            True when such nodes exist (after possibly raising), False when
            every looked-up node had a recorded value.

        Raises:
            ValueError: if an unrecorded node is used by an expr of the
                scope in which it went missing.
        """
        if self.node_without_tensor_info:
            for node, info in self.node_without_tensor_info.items():
                for expr in info[0]._exprs:
                    if node in expr.inputs or node in expr.outputs:
                        traceback.print_list(info[1])
                        raise ValueError(
                            "node({}) not in the graph:\n{}".format(node, info[0])
                        )
            return True
        else:
            return False

    def check_net_outputs(self, tm_res, gt_res):
        """Assert that the traced module's final outputs match the original's."""
        if isinstance(tm_res, Tensor):
            np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
        elif isinstance(tm_res, Sequence):
            for i, j in zip(tm_res, gt_res):
                np.testing.assert_allclose(i.numpy(), j.numpy())
        else:
            # Arbitrary result object: compare its tensor attributes one by one.
            for k in tm_res.__dict__.keys():
                np.testing.assert_allclose(
                    getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy()
                )

    def record_nodemixin(self, node, value):
        """Record the live object backing a NodeMixin node in the current scope."""
        self.current_node2values()[node] = value

    def record_node2value(self, node, value):
        """Record a defensive copy of a tensor node's value in the current scope.

        The copy is taken with tracing suspended so it does not pollute the
        traced graph.
        """
        with _exclude_from_trace():
            self.current_node2values()[node] = apply(
                Copy(comp_node=value.device), value
            )[0]
            if isscalar(value):
                setscalar(self.current_node2values()[node])

    def check_apply_special_cases(self, opdef, num_outputs):
        """Return the output indices that should be compared for ``opdef``.

        ROIAlign / ROIPooling in AVERAGE mode produce a trailing index output
        that interpretation does not reproduce, so it is excluded.
        """
        indexs = list(range(num_outputs))
        if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
            indexs.pop(-1)
        if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
            indexs.pop(-1)
        return indexs

    def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
        """Assert expr outputs equal the ground truth, optionally only at
        the given output indices."""
        expr_outputs = (
            (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs
        )
        gt_outputs = (
            (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs
        )
        if indexs is not None:
            for i in indexs:
                np.testing.assert_allclose(
                    expr_outputs[i].numpy(), gt_outputs[i].numpy()
                )
        else:
            np.testing.assert_allclose(expr_outputs, gt_outputs)

    def get_node2value(self, inputs, start_idx=0):
        """Look up recorded values for ``inputs[start_idx:]``.

        Returns:
            A ``(values, has_node_not_in_scope)`` pair. Missing nodes are
            remembered in ``node_without_tensor_info`` together with the
            current scope and a stack trace for later diagnosis.
        """
        inp_values = []
        has_node_not_in_scope = False
        for i in range(start_idx, len(inputs)):
            try:
                inp_values.append(self.current_node2values()[inputs[i]])
            # KeyError: node never recorded; IndexError: no active scope.
            except (KeyError, IndexError):
                has_node_not_in_scope = True
                self.node_without_tensor_info[inputs[i]] = [
                    self.tracer.current_scope(),
                    traceback.extract_stack(),
                ]
        return inp_values, has_node_not_in_scope

    def check_expr_interpret(self, expr, gt_outputs):
        """Re-interpret ``expr`` on recorded inputs and compare with ``gt_outputs``."""
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as exc:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from exc

    def check_apply(self, expr, gt_outputs, opdef):
        """Check an Apply expr, honouring op-specific output exclusions."""
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
            try:
                self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
            except Exception as exc:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from exc

    def check_builtin_module(self, module, expr, gt_outputs):
        """Check a builtin-module call expr against ``gt_outputs``.

        ``expr.inputs[0]`` is the module node itself, so value lookup starts
        at index 1 and the real module object is re-inserted at the front.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
        if not has_node_not_in_scope:
            ori_in.insert(0, module)
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as exc:
                # BUG FIX: the original format string had two placeholders but
                # a single argument, raising IndexError instead of the
                # intended ValueError.
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from exc

+ 9
- 0
imperative/python/megengine/traced_module/expr.py View File

@@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import _ModuleState
from .tm_config import _exclude_from_trace, _get_expr_checker
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args


@@ -611,6 +612,8 @@ class Apply(Expr):
inp_nodes = [NodeMixin.get(inputs[0])]
for i in inputs[1:]:
node = Constant.make(i)
if _get_expr_checker():
active_module_tracer().checker.record_node2value(node, Tensor(i))
inp_nodes.append(node)
apply_node = cls.make(opdef)
for n in inp_nodes:
@@ -624,11 +627,17 @@ class Apply(Expr):

unset_module_tracing()
outputs = apply(opdef, *inputs)
outputs = list(map(Tensor, outputs))
set_module_tracing()

apply_node.add_outputs(outputs)
for n, v in zip(apply_node.outputs, outputs):
NodeMixin.wrap_safe(v, n)

if _get_expr_checker():
with _exclude_from_trace():
active_module_tracer().checker.check_apply(apply_node, outputs, opdef)

return list(outputs)




+ 4
- 0
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -12,6 +12,7 @@ from .. import functional as F
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
from .checker import TracedModuleChecker

_active_module_tracer = None

@@ -128,6 +129,7 @@ class module_tracer:

def __init__(self, wrap_fn):
self._active_scopes = []
self.checker = TracedModuleChecker(self)
self.patcher = Patcher(wrap_fn)

@classmethod
@@ -142,9 +144,11 @@ class module_tracer:

def push_scope(self, scope):
self._active_scopes.append(scope)
self.checker.push_scope()

def pop_scope(self):
self._active_scopes.pop()
self.checker.pop_scope()

def current_scope(self):
if self._active_scopes:


+ 17
- 0
imperative/python/megengine/traced_module/node.py View File

@@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor
from .module_tracer import active_module_tracer
from .tm_config import _get_expr_checker
from .utils import _check_obj_attr

logger = get_logger(__name__)
@@ -343,6 +345,11 @@ class NodeMixin(abc.ABC):
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
setattr(value, "_NodeMixin__node", node)
if _get_expr_checker():
if isinstance(value, RawTensor):
active_module_tracer().checker.record_node2value(node, value)
if isinstance(value, NodeMixin):
active_module_tracer().checker.record_nodemixin(node, value)
else:
assert callable(node)
n = node()
@@ -352,6 +359,11 @@ class NodeMixin(abc.ABC):
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(n)
setattr(value, "_NodeMixin__node", n)
if _get_expr_checker():
if isinstance(value, RawTensor):
active_module_tracer().checker.record_node2value(n, value)
if isinstance(value, NodeMixin):
active_module_tracer().checker.record_nodemixin(n, value)

@classmethod
def wrap_safe(cls, value, node):
@@ -359,6 +371,11 @@ class NodeMixin(abc.ABC):
if isinstance(value, RawTensor):
cls._record_tensornode_property(node, value)
setattr(value, "_NodeMixin__node", node)
if _get_expr_checker():
if isinstance(value, RawTensor):
active_module_tracer().checker.record_node2value(node, value)
if isinstance(value, NodeMixin):
active_module_tracer().checker.record_nodemixin(node, value)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)



+ 5
- 1
imperative/python/megengine/traced_module/pytree.py View File

@@ -212,7 +212,11 @@ def tree_flatten(
to reconstruct the pytree.
"""
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values), values
assert is_leaf(
values
), 'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format(
values
)
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
node.const_val = values


+ 55
- 0
imperative/python/megengine/traced_module/tm_config.py View File

@@ -0,0 +1,55 @@
import contextlib

from ..core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)

_enable_expr_checker = False
_enable_default_checker = True


def _get_expr_checker():
return _enable_expr_checker


def _get_default_checker():
return _enable_default_checker


def enable_expr_checker():
r"""Call this function to check the result of each expr during tracing."""
global _enable_expr_checker
_enable_expr_checker = True
_enable_default_checker = False


def disable_default_checker():
r"""Call this function to disable checking the final output of the model after tracing."""
global _enable_default_checker
_enable_default_checker = False


_enable_graph_surgery_mode = False


def _graph_surgery_mode():
return _enable_graph_surgery_mode


def _set_graph_surgery_mode(mode: bool):
global _enable_graph_surgery_mode
pre_mode = _enable_graph_surgery_mode
_enable_graph_surgery_mode = mode
return pre_mode


@contextlib.contextmanager
def _exclude_from_trace():
    """Context manager that temporarily suspends module tracing.

    BUG FIX: the original did not wrap ``yield`` in try/finally, so an
    exception raised inside the managed block left module tracing
    permanently disabled. The finally clause restores the previous state
    unconditionally.
    """
    was_tracing = is_tracing_module()
    if was_tracing:
        unset_module_tracing()
    try:
        yield
    finally:
        if was_tracing:
            set_module_tracing()

+ 48
- 18
imperative/python/megengine/traced_module/traced_module.py View File

@@ -36,11 +36,14 @@ from .. import get_logger
from .. import module as M
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
apply,
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
@@ -98,6 +101,13 @@ from .serialization import (
load_call_tensor_method_expr,
load_functional,
)
from .tm_config import (
_exclude_from_trace,
_get_default_checker,
_get_expr_checker,
_graph_surgery_mode,
_set_graph_surgery_mode,
)
from .utils import (
_check_builtin_module_attr,
_check_obj_attr,
@@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool:


def _is_leaf(node):
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
assert isinstance(
node, RawTensor
), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format(
type(node)
)
return isinstance(node, RawTensor)


_enable_graph_surgery_mode = False


def _graph_surgery_mode():
return _enable_graph_surgery_mode


def _set_graph_surgery_mode(mode: bool):
global _enable_graph_surgery_mode
pre_mode = _enable_graph_surgery_mode
_enable_graph_surgery_mode = mode
return pre_mode


def _node_to_tensor(*args, **kwargs):
tensors = []
nodes, tree_def = tree_flatten((args, kwargs))
@@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func):
return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1])
inputs[1] = copy.copy(inputs[1])
is_scalar = isscalar(inputs[1])
inputs[1] = apply(
Copy(comp_node=inputs[1].device), Tensor(inputs[1])
)[0]
if is_scalar:
setscalar(inputs[1])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor,
# which will cause they have same _NodeMixin__node in tracing.
NodeMixin.wrap_safe(inputs[1], node)
@@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func):
else:
outputs = None
call_node.add_outputs(outputs)

if _get_expr_checker():
with _exclude_from_trace():
active_module_tracer().checker.check_expr_interpret(
call_node, outputs
)

set_module_tracing()
return rst
return orig_func(*args, **kwargs)
@@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin):
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
if _get_expr_checker():
with _exclude_from_trace():
tmp = self.build()
active_module_tracer().checker.check_builtin_module(
tmp, callnode, outputs
)
set_module_tracing()
if self._is_builtin:
self._body = None
@@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin):
if not isinstance(mod_attr, (List, Dict, QATModule)):
assert mod_attr is wrapped._mod
else:
assert mod_attr is wrapped
assert (
mod_attr is wrapped
), "TracedModule do not support modify attributes, please check your code."

if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
@@ -2469,11 +2487,23 @@ def trace_module(
qualname="{}.[{}]".format(net_name, "arg_{}".format(_)),
),
)
builder(*args, **kwargs)
rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs))
active_module_tracer().pop_scope()
traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids()

has_expr_not_check = False
if _get_expr_checker():
has_expr_not_check = (
active_module_tracer().checker.check_node_not_in_scope()
)
if _get_default_checker() or has_expr_not_check:
with _exclude_from_trace():
tm_res = traced_mod(*args, **kwargs)
tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf)
rst, _ = tree_flatten(rst, is_leaf=_is_leaf)
active_module_tracer().checker.check_net_outputs(tm_res, rst)
return traced_mod
finally:
set_symbolic_shape(use_sym_shape)


+ 2
- 3
imperative/python/megengine/traced_module/utils.py View File

@@ -5,16 +5,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import inspect
from collections.abc import MutableMapping, MutableSequence
from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

from .. import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from ..tensor import Tensor

logger = get_logger(__name__)



+ 4
- 0
imperative/python/test/unit/traced_module/test_qat_module.py View File

@@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls):
)
Q.enable_observer(qat_net)
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
qat_net.eval()
qat_net(inp)
Q.disable_observer(qat_net)
return qat_net
@@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls):

def build_fakequanted_net(net: QATModule, fakequant_cls):
qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls))
qat_net.eval()
return qat_net


@@ -162,6 +164,7 @@ def test_load_param():

def _check_module(build_func: Callable):
net = build_func()
net.eval()
buffer = io.BytesIO()
mge.save(net.state_dict(), buffer)
buffer.seek(0)
@@ -185,6 +188,7 @@ def test_load_param():
def test_qualname():
def _check_qualname(net):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
net.eval()
traced_net = trace_module(net, inp)
base_qualname = traced_net.graph.qualname
for node in traced_net.graph.nodes():


Loading…
Cancel
Save