|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import traceback
- from typing import Sequence
-
- import numpy as np
-
- from ..core._imperative_rt.core2 import apply
- from ..core._imperative_rt.ops import ROIAlign, ROIPooling
- from ..core.ops.builtin import Copy
- from ..tensor import Tensor
- from .tm_config import _exclude_from_trace
-
-
class TracedModuleChecker:
    """Check a traced module against the values observed while tracing it.

    During tracing, the concrete value of each graph node is recorded into a
    stack of per-scope ``{node: value}`` dicts. The ``check_*`` methods replay
    an expr through ``expr.interpret`` on those recorded inputs and compare the
    result with the ground-truth outputs. Nodes whose values are missing from
    the active scope are remembered in ``node_without_tensor_info`` so that
    :meth:`check_node_not_in_scope` can report them later.
    """

    def __init__(self, tracer):
        # Stack of ``{node: value}`` dicts; the last one is the active scope.
        self._active_node2values = []
        self.tracer = tracer

        # node -> [tracer.current_scope(), traceback.extract_stack()] for every
        # node whose value could not be found when an expr was checked.
        self.node_without_tensor_info = {}

    def push_scope(self):
        """Open a new, empty ``{node: value}`` scope."""
        self._active_node2values.append({})

    def pop_scope(self):
        """Close the innermost scope and discard the values recorded in it."""
        self._active_node2values.pop()

    def current_node2values(self):
        """Return the ``{node: value}`` dict of the innermost scope.

        Raises:
            IndexError: if no scope is open.
        """
        return self._active_node2values[-1]

    def reset_checker(self):
        """Close every scope.

        ``node_without_tensor_info`` is not cleared here.
        """
        self._active_node2values = []

    def check_node_not_in_scope(self):
        """Report nodes whose values were missing when an expr was checked.

        Returns:
            ``False`` if no missing node has been recorded. ``True`` if some
            were recorded but none of them is an input/output of an expr in
            the scope it was looked up from.

        Raises:
            ValueError: for the first missing node that is an input or output
                of an expr in its recorded scope. The stack captured when the
                node was found missing is printed first to locate the caller.
        """
        if not self.node_without_tensor_info:
            return False
        for node, info in self.node_without_tensor_info.items():
            scope, stack = info[0], info[1]
            for expr in scope._exprs:
                if node in expr.inputs or node in expr.outputs:
                    traceback.print_list(stack)
                    raise ValueError(
                        "node({}) not in the graph:\n{}".format(node, scope)
                    )
        return True

    def check_net_outputs(self, tm_res, gt_res):
        """Assert that the traced module's outputs match the ground truth.

        ``tm_res``/``gt_res`` may be a single ``Tensor``, a sequence of tensors
        (compared pairwise; ``zip`` stops at the shorter one), or an object
        whose instance attributes are tensors (compared by the attribute names
        of ``tm_res``).

        Raises:
            AssertionError: from ``np.testing.assert_allclose`` on mismatch.
        """
        if isinstance(tm_res, Tensor):
            np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
        elif isinstance(tm_res, Sequence):
            for tm_item, gt_item in zip(tm_res, gt_res):
                np.testing.assert_allclose(tm_item.numpy(), gt_item.numpy())
        else:
            for name in tm_res.__dict__:
                np.testing.assert_allclose(
                    getattr(tm_res, name).numpy(), getattr(gt_res, name).numpy()
                )

    def record_nodemixin(self, node, value):
        """Record ``value`` for ``node`` in the innermost scope, as-is."""
        self.current_node2values()[node] = value

    def record_node2value(self, node, value):
        """Record a copy of tensor ``value`` for ``node`` in the innermost scope.

        The copy is made by applying a ``Copy`` op on ``value.device`` with
        tracing disabled, so the recorded object is a new tensor rather than
        ``value`` itself. Presumably this guards the recorded value against
        later changes to ``value``; TODO confirm.
        """
        with _exclude_from_trace():
            self.current_node2values()[node] = apply(
                Copy(comp_node=value.device), value
            )[0]

    def check_apply_special_cases(self, opdef, num_outputs):
        """Return the output indices that should be compared for ``opdef``.

        All ``range(num_outputs)`` indices are returned, except for
        ``ROIAlign``/``ROIPooling`` in ``AVERAGE`` mode, where the last output
        index is dropped. Presumably that output is not meaningful in average
        mode; TODO confirm.
        """
        indexs = list(range(num_outputs))
        if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
            indexs.pop(-1)
        if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
            indexs.pop(-1)
        return indexs

    def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
        """Assert that replayed ``expr_outputs`` match ``gt_outputs``.

        Non-sequence arguments are wrapped into 1-tuples. If ``indexs`` is
        given, only those positions are compared, each converted with
        ``.numpy()``. Otherwise both sequences are passed to
        ``np.testing.assert_allclose`` directly.

        Raises:
            AssertionError: from ``np.testing.assert_allclose`` on mismatch.
        """
        if not isinstance(expr_outputs, Sequence):
            expr_outputs = (expr_outputs,)
        if not isinstance(gt_outputs, Sequence):
            gt_outputs = (gt_outputs,)
        if indexs is None:
            np.testing.assert_allclose(expr_outputs, gt_outputs)
            return
        for i in indexs:
            np.testing.assert_allclose(expr_outputs[i].numpy(), gt_outputs[i].numpy())

    def get_node2value(self, inputs, start_idx=0):
        """Look up the recorded values of ``inputs[start_idx:]``.

        Returns:
            ``(values, has_node_not_in_scope)``. ``values`` holds the values
            that were found, in input order. The flag is ``True`` if at least
            one input had no recorded value. Each missing node is also stored
            in ``node_without_tensor_info`` together with the tracer's current
            scope and the current stack.
        """
        inp_values = []
        has_node_not_in_scope = False
        for node in inputs[start_idx:]:
            try:
                inp_values.append(self.current_node2values()[node])
            except (KeyError, IndexError):
                # KeyError: node not recorded in the active scope.
                # IndexError: no scope is open at all.
                # Either way the value is unavailable, so remember where it
                # was requested from for check_node_not_in_scope().
                has_node_not_in_scope = True
                self.node_without_tensor_info[node] = [
                    self.tracer.current_scope(),
                    traceback.extract_stack(),
                ]
        return inp_values, has_node_not_in_scope

    def check_expr_interpret(self, expr, gt_outputs):
        """Replay ``expr`` on recorded inputs and compare with ``gt_outputs``.

        Nothing is checked if any input value is unavailable.

        Raises:
            ValueError: if the results do not match. The underlying comparison
                error is chained as ``__cause__``.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if has_node_not_in_scope:
            return
        expr_res = expr.interpret(*ori_in)
        try:
            self.check_expr_results(expr_res, gt_outputs)
        except Exception as exc:
            raise ValueError(
                "Error occurred when checking expr: {}".format(expr)
            ) from exc

    def check_apply(self, expr, gt_outputs, opdef):
        """Replay an op-apply ``expr`` and compare the relevant outputs.

        Only the indices from :meth:`check_apply_special_cases` are compared.
        Nothing is checked if any input value is unavailable.

        Raises:
            ValueError: if the results do not match. The underlying comparison
                error is chained as ``__cause__``.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if has_node_not_in_scope:
            return
        expr_res = expr.interpret(*ori_in)
        indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
        try:
            self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
        except Exception as exc:
            raise ValueError(
                "Error occurred when checking expr: {}".format(expr)
            ) from exc

    def check_builtin_module(self, module, expr, gt_outputs):
        """Replay a builtin-module call ``expr`` and compare with ``gt_outputs``.

        ``expr.inputs[0]`` is not looked up; the live ``module`` object is
        passed in its place. Nothing is checked if any other input value is
        unavailable.

        Raises:
            ValueError: if the results do not match. The message is prefixed
                with ``module`` and the underlying comparison error is chained
                as ``__cause__``.
        """
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
        if has_node_not_in_scope:
            return
        ori_in.insert(0, module)
        expr_res = expr.interpret(*ori_in)
        try:
            self.check_expr_results(expr_res, gt_outputs)
        except Exception as exc:
            # Both placeholders need an argument; passing only ``expr`` made
            # this line raise IndexError instead of the intended ValueError.
            raise ValueError(
                "{}, Error occurred when checking expr: {}".format(module, expr)
            ) from exc
|