You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

checker.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import traceback
  9. from typing import Sequence
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core._imperative_rt.ops import ROIAlign, ROIPooling
  13. from ..core.ops.builtin import Copy
  14. from ..tensor import Tensor
  15. from .tm_config import _exclude_from_trace
  16. class TracedModuleChecker:
  17. def __init__(self, tracer):
  18. self._active_node2values = []
  19. self.tracer = tracer
  20. self.node_without_tensor_info = {}
  21. def push_scope(self):
  22. self._active_node2values.append({})
  23. def pop_scope(self):
  24. self._active_node2values.pop()
  25. def current_node2values(self):
  26. return self._active_node2values[-1]
  27. def reset_checker(self):
  28. self._active_node2values = []
  29. def check_node_not_in_scope(self):
  30. if self.node_without_tensor_info:
  31. for node, info in self.node_without_tensor_info.items():
  32. for expr in info[0]._exprs:
  33. if node in expr.inputs or node in expr.outputs:
  34. traceback.print_list(info[1])
  35. raise ValueError(
  36. "node({}) not in the graph:\n{}".format(node, info[0])
  37. )
  38. return True
  39. else:
  40. return False
  41. def check_net_outputs(self, tm_res, gt_res):
  42. if isinstance(tm_res, Tensor):
  43. np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
  44. elif isinstance(tm_res, Sequence):
  45. for i, j in zip(tm_res, gt_res):
  46. np.testing.assert_allclose(i.numpy(), j.numpy())
  47. else:
  48. for k in tm_res.__dict__.keys():
  49. np.testing.assert_allclose(
  50. getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy()
  51. )
  52. def record_nodemixin(self, node, value):
  53. self.current_node2values()[node] = value
  54. def record_node2value(self, node, value):
  55. with _exclude_from_trace():
  56. self.current_node2values()[node] = apply(
  57. Copy(comp_node=value.device), value
  58. )[0]
  59. def check_apply_special_cases(self, opdef, num_outputs):
  60. indexs = list(range(num_outputs))
  61. if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
  62. indexs.pop(-1)
  63. if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
  64. indexs.pop(-1)
  65. return indexs
  66. def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
  67. expr_outputs = (
  68. (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs
  69. )
  70. gt_outputs = (
  71. (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs
  72. )
  73. if indexs is not None:
  74. for i in indexs:
  75. np.testing.assert_allclose(
  76. expr_outputs[i].numpy(), gt_outputs[i].numpy()
  77. )
  78. else:
  79. np.testing.assert_allclose(expr_outputs, gt_outputs)
  80. def get_node2value(self, inputs, start_idx=0):
  81. inp_values = []
  82. has_node_not_in_scope = False
  83. for i in range(start_idx, len(inputs)):
  84. try:
  85. inp_values.append(self.current_node2values()[inputs[i]])
  86. except:
  87. has_node_not_in_scope = True
  88. self.node_without_tensor_info[inputs[i]] = [
  89. self.tracer.current_scope(),
  90. traceback.extract_stack(),
  91. ]
  92. return inp_values, has_node_not_in_scope
  93. def check_expr_interpret(self, expr, gt_outputs):
  94. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
  95. if not has_node_not_in_scope:
  96. expr_res = expr.interpret(*ori_in)
  97. try:
  98. self.check_expr_results(expr_res, gt_outputs)
  99. except:
  100. raise ValueError("Error occurred when checking expr: {}".format(expr))
  101. def check_apply(self, expr, gt_outputs, opdef):
  102. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
  103. if not has_node_not_in_scope:
  104. expr_res = expr.interpret(*ori_in)
  105. indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
  106. try:
  107. self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
  108. except:
  109. raise ValueError("Error occurred when checking expr: {}".format(expr))
  110. def check_builtin_module(self, module, expr, gt_outputs):
  111. ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
  112. if not has_node_not_in_scope:
  113. ori_in.insert(0, module)
  114. expr_res = expr.interpret(*ori_in)
  115. try:
  116. self.check_expr_results(expr_res, gt_outputs)
  117. except:
  118. raise ValueError(
  119. "{}, Error occurred when checking expr: {}".format(expr)
  120. )