You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

checker.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import traceback
  2. from typing import Sequence
  3. import numpy as np
  4. from ..core._imperative_rt.core2 import apply
  5. from ..core._imperative_rt.ops import ROIAlign, ROIPooling
  6. from ..core.ops.builtin import Copy
  7. from ..tensor import Tensor
  8. from .tm_config import _exclude_from_trace
  class TracedModuleChecker:
      """Check a traced module against the computation it was traced from.

      Values produced while tracing are recorded per graph node in a stack of
      scopes (``{node: value}`` dicts). A traced expr can then be re-run through
      ``expr.interpret`` on the recorded input values, and its results compared
      with the ground-truth outputs of the real computation.
      """

      def __init__(self, tracer):
          # Stack of ``{node: value}`` dicts; the last one is the active scope.
          self._active_node2values = []
          # Only ``tracer.current_scope()`` is used here (in get_node2value).
          self.tracer = tracer
          # node -> [scope, stack summary] for each node whose value could not
          # be found in the active scope; inspected by check_node_not_in_scope.
          self.node_without_tensor_info = {}

      def push_scope(self):
          """Open a new, empty node->value scope and make it the active one."""
          self._active_node2values.append({})

      def pop_scope(self):
          """Discard the active scope (raises IndexError if none is open)."""
          self._active_node2values.pop()

      def current_node2values(self):
          """Return the active ``{node: value}`` dict (IndexError if none is open)."""
          return self._active_node2values[-1]

      def reset_checker(self):
          """Drop every open scope.

          NOTE(review): ``node_without_tensor_info`` is not cleared here; confirm
          whether callers expect it to be reset as well.
          """
          self._active_node2values = []

      def check_node_not_in_scope(self):
          """Report nodes whose values were missing when they were looked up.

          Returns:
              ``False`` if no lookup has failed so far, ``True`` if some lookup
              failed but none of those nodes is used by an expr of the scope it
              was looked up in.

          Raises:
              ValueError: if a missing node appears among the inputs or outputs
                  of an expr in that scope. The call stack captured at lookup
                  time is printed before raising.
          """
          if not self.node_without_tensor_info:
              return False
          for node, (scope, stack) in self.node_without_tensor_info.items():
              for expr in scope._exprs:
                  if node in expr.inputs or node in expr.outputs:
                      traceback.print_list(stack)
                      raise ValueError(
                          "node({}) not in the graph:\n{}".format(node, scope)
                      )
          return True

      def check_net_outputs(self, tm_res, gt_res):
          """Assert that the traced module's outputs match the original ones.

          ``tm_res`` may be a single ``Tensor``, a sequence of tensors, or an
          object whose instance attributes are tensors; ``gt_res`` is expected
          to have the same structure.

          Raises:
              AssertionError: on any value mismatch, or when the two output
                  sequences have different lengths.
          """
          if isinstance(tm_res, Tensor):
              np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
          elif isinstance(tm_res, Sequence):
              # zip() stops at the shorter sequence, which would silently skip
              # missing or extra outputs, so compare the lengths first.
              if len(tm_res) != len(gt_res):
                  raise AssertionError(
                      "number of outputs mismatch: traced module returned {}, "
                      "expected {}".format(len(tm_res), len(gt_res))
                  )
              for tm_out, gt_out in zip(tm_res, gt_res):
                  np.testing.assert_allclose(tm_out.numpy(), gt_out.numpy())
          else:
              for name in tm_res.__dict__:
                  np.testing.assert_allclose(
                      getattr(tm_res, name).numpy(), getattr(gt_res, name).numpy()
                  )

      def record_nodemixin(self, node, value):
          """Record ``value`` for ``node`` in the active scope as-is (no copy)."""
          self.current_node2values()[node] = value

      def record_node2value(self, node, value):
          """Record a copy of tensor ``value`` for ``node`` in the active scope.

          The copy is made with the ``Copy`` op on ``value.device`` inside
          ``_exclude_from_trace()``. Presumably this keeps the copy itself out
          of the trace and shields the recorded value from later in-place
          updates (TODO confirm).
          """
          with _exclude_from_trace():
              self.current_node2values()[node] = apply(
                  Copy(comp_node=value.device), value
              )[0]

      def check_apply_special_cases(self, opdef, num_outputs):
          """Return the output indices of ``opdef`` that should be compared.

          All ``range(num_outputs)`` indices are returned, except that the last
          output of ``ROIAlign`` / ``ROIPooling`` in ``AVERAGE`` mode is dropped
          (presumably because that output is not meaningful in this mode).
          """
          indexs = list(range(num_outputs))
          if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
              indexs.pop(-1)
          if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
              indexs.pop(-1)
          return indexs

      def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
          """Assert that an expr's outputs match the ground-truth outputs.

          Values that are not a ``Sequence`` are wrapped into 1-tuples first.
          With ``indexs``, only those positions are compared, each converted
          via ``.numpy()``. Otherwise the whole tuples are handed to
          ``np.testing.assert_allclose`` as they are.

          Raises:
              AssertionError: from ``np.testing.assert_allclose`` on mismatch.
          """
          if not isinstance(expr_outputs, Sequence):
              expr_outputs = (expr_outputs,)
          if not isinstance(gt_outputs, Sequence):
              gt_outputs = (gt_outputs,)
          if indexs is None:
              np.testing.assert_allclose(expr_outputs, gt_outputs)
              return
          for i in indexs:
              np.testing.assert_allclose(
                  expr_outputs[i].numpy(), gt_outputs[i].numpy()
              )

      def get_node2value(self, inputs, start_idx=0):
          """Look up the recorded values of ``inputs[start_idx:]``.

          Returns:
              ``(values, has_node_not_in_scope)``. ``values`` holds the values
              that were found, in input order. The flag is ``True`` if at
              least one node had no recorded value; each such node is stored in
              ``node_without_tensor_info`` together with the tracer's current
              scope and the current call stack.
          """
          inp_values = []
          has_node_not_in_scope = False
          for node in inputs[start_idx:]:
              try:
                  # LookupError covers a node missing from the active scope
                  # (KeyError) as well as no scope being open (IndexError).
                  inp_values.append(self.current_node2values()[node])
              except LookupError:
                  has_node_not_in_scope = True
                  self.node_without_tensor_info[node] = [
                      self.tracer.current_scope(),
                      traceback.extract_stack(),
                  ]
          return inp_values, has_node_not_in_scope

      def _compare_or_raise(self, expr, expr_res, gt_outputs, indexs=None, module=None):
          # Compare interpreted results with ground truth; wrap any failure in a
          # ValueError naming ``expr`` (prefixed by ``module`` when given). The
          # message is only formatted on failure, so reprs are not built eagerly.
          try:
              self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
          except Exception as exc:
              if module is None:
                  msg = "Error occurred when checking expr: {}".format(expr)
              else:
                  msg = "{}, Error occurred when checking expr: {}".format(
                      module, expr
                  )
              raise ValueError(msg) from exc

      def check_expr_interpret(self, expr, gt_outputs):
          """Re-run ``expr`` on recorded inputs and compare with ``gt_outputs``.

          Nothing is checked if any input has no recorded value.

          Raises:
              ValueError: if the results do not match (cause chained).
          """
          ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
          if has_node_not_in_scope:
              return
          expr_res = expr.interpret(*ori_in)
          self._compare_or_raise(expr, expr_res, gt_outputs)

      def check_apply(self, expr, gt_outputs, opdef):
          """Like check_expr_interpret, but only compare outputs selected by
          check_apply_special_cases for ``opdef``.

          Raises:
              ValueError: if the selected results do not match (cause chained).
          """
          ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
          if has_node_not_in_scope:
              return
          expr_res = expr.interpret(*ori_in)
          indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
          self._compare_or_raise(expr, expr_res, gt_outputs, indexs=indexs)

      def check_builtin_module(self, module, expr, gt_outputs):
          """Re-run a module-call ``expr`` with ``module`` as its first argument.

          ``expr.inputs[0]`` is skipped during lookup and replaced by ``module``
          itself; the remaining inputs come from the active scope. Nothing is
          checked if any of them has no recorded value.

          Raises:
              ValueError: if the results do not match; the message is prefixed
                  with ``module`` (cause chained).
          """
          ori_in, has_node_not_in_scope = self.get_node2value(
              expr.inputs, start_idx=1
          )
          if has_node_not_in_scope:
              return
          ori_in.insert(0, module)
          expr_res = expr.interpret(*ori_in)
          self._compare_or_raise(expr, expr_res, gt_outputs, module=module)