You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

checker.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import traceback
  9. from typing import Sequence
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core._imperative_rt.ops import ROIAlign, ROIPooling
  13. from ..core.ops.builtin import Copy
  14. from ..core.tensor.utils import isscalar, setscalar
  15. from ..tensor import Tensor
  16. from .tm_config import _exclude_from_trace
  17. class TracedModuleChecker:
  18. def __init__(self, tracer):
  19. self._active_node2values = []
  20. self.tracer = tracer
  21. self.node_without_tensor_info = {}
  22. def push_scope(self):
  23. self._active_node2values.append({})
  24. def pop_scope(self):
  25. self._active_node2values.pop()
  26. def current_node2values(self):
  27. return self._active_node2values[-1]
  28. def reset_checker(self):
  29. self._active_node2values = []
  30. def check_node_not_in_scope(self):
  31. if self.node_without_tensor_info:
  32. for node, info in self.node_without_tensor_info.items():
  33. for expr in info[0]._exprs:
  34. if node in expr.inputs or node in expr.outputs:
  35. traceback.print_list(info[1])
  36. raise ValueError(
  37. "node({}) not in the graph:\n{}".format(node, info[0])
  38. )
  39. return True
  40. else:
  41. return False
  42. def check_net_outputs(self, tm_res, gt_res):
  43. if isinstance(tm_res, Tensor):
  44. np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
  45. elif isinstance(tm_res, Sequence):
  46. for i, j in zip(tm_res, gt_res):
  47. np.testing.assert_allclose(i.numpy(), j.numpy())
  48. else:
  49. for k in tm_res.__dict__.keys():
  50. np.testing.assert_allclose(
  51. getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy()
  52. )
  53. def record_nodemixin(self, node, value):
  54. self.current_node2values()[node] = value
  55. def record_node2value(self, node, value):
  56. with _exclude_from_trace():
  57. self.current_node2values()[node] = apply(
  58. Copy(comp_node=value.device), value
  59. )[0]
  60. if isscalar(value):
  61. setscalar(self.current_node2values()[node])
  62. def check_apply_special_cases(self, opdef, num_outputs):
  63. indexs = list(range(num_outputs))
  64. if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
  65. indexs.pop(-1)
  66. if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
  67. indexs.pop(-1)
  68. return indexs
  69. def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
  70. expr_outputs = (
  71. (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs
  72. )
  73. gt_outputs = (
  74. (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs
  75. )
  76. if indexs is not None:
  77. for i in indexs:
  78. np.testing.assert_allclose(
  79. expr_outputs[i].numpy(), gt_outputs[i].numpy()
  80. )
  81. else:
  82. np.testing.assert_allclose(expr_outputs, gt_outputs)
  83. def get_node2value(self, inputs, start_idx=0):
  84. inp_values = []
  85. has_node_not_in_scope = False
  86. for i in range(start_idx, len(inputs)):
  87. try:
  88. inp_values.append(self.current_node2values()[inputs[i]])
  89. except:
  90. has_node_not_in_scope = True
  91. self.node_without_tensor_info[inputs[i]] = [
  92. self.tracer.current_scope(),
  93. traceback.extract_stack(),
  94. ]
  95. return inp_values, has_node_not_in_scope
  96. def check_expr_interpret(self, expr, gt_outputs):
  97. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
  98. if not has_node_not_in_scope:
  99. expr_res = expr.interpret(*ori_in)
  100. try:
  101. self.check_expr_results(expr_res, gt_outputs)
  102. except:
  103. raise ValueError("Error occurred when checking expr: {}".format(expr))
  104. def check_apply(self, expr, gt_outputs, opdef):
  105. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
  106. if not has_node_not_in_scope:
  107. expr_res = expr.interpret(*ori_in)
  108. indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
  109. try:
  110. self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
  111. except:
  112. raise ValueError("Error occurred when checking expr: {}".format(expr))
  113. def check_builtin_module(self, module, expr, gt_outputs):
  114. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
  115. if not has_node_not_in_scope:
  116. ori_in.insert(0, module)
  117. expr_res = expr.interpret(*ori_in)
  118. try:
  119. self.check_expr_results(expr_res, gt_outputs)
  120. except:
  121. raise ValueError(
  122. "{}, Error occurred when checking expr: {}".format(expr)
  123. )