You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matcher.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from collections import OrderedDict, defaultdict
  2. from functools import partial
  3. from ...logger import get_logger
  4. from ..expr import (
  5. Expr,
  6. is_apply_def,
  7. is_call_function,
  8. is_call_module,
  9. is_call_tensor_method,
  10. is_constant,
  11. )
  12. from .pattern import (
  13. AnyPattern,
  14. ApplyDefPattern,
  15. CallPattern,
  16. ConstantPattern,
  17. ExprPattern,
  18. FunctionPattern,
  19. ModulePattern,
  20. OrPattern,
  21. TensorMethodPattern,
  22. VarPattern,
  23. )
  24. from .utils import register_obj
# Module-level logger; PatternMatcher._check_users uses it to warn when a
# matched expr has more than one output.
logger = get_logger(__name__)
  26. class PatternMatcher:
  27. method_dict = {}
  28. register_visiter_func = partial(register_obj, _dict=method_dict)
  29. def __init__(self) -> None:
  30. self.matched_patterns = []
  31. self.matched_exprs = OrderedDict()
  32. def match(self, pattern: ExprPattern, expr: Expr) -> bool:
  33. self.matched_exprs.clear()
  34. self.matched_patterns.clear()
  35. pattern.check_users(False)
  36. res = self.visit_pattern(pattern, expr)
  37. if res and not self._check_users():
  38. self.clear_map(0)
  39. res = False
  40. self._clear_pattern_users()
  41. return res
  42. def clear_map(self, mark):
  43. for _ in range(len(self.matched_patterns) - mark):
  44. p = self.matched_patterns.pop()
  45. self.matched_exprs.pop(p)
  46. p._clear_users()
  47. def _clear_pattern_users(self):
  48. for p in self.matched_patterns:
  49. p._clear_users()
  50. def _check_users(self) -> bool:
  51. for pat, expr in self.matched_exprs.items():
  52. if pat._check_users:
  53. pattern_users = pat._users
  54. if len(expr.outputs) != 1:
  55. logger.warning(
  56. "only support single output, and the matching "
  57. "result may be wrong"
  58. )
  59. continue
  60. expr_users = expr.outputs[0].users
  61. if len(pattern_users) != len(expr_users):
  62. return False
  63. for pat, expr in zip(pattern_users, expr_users):
  64. if self.matched_exprs[pat] != expr:
  65. return False
  66. return True
  67. def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool:
  68. if pattern in self.matched_exprs:
  69. if self.matched_exprs[pattern] is expr:
  70. if isinstance(pattern, (OrPattern)):
  71. assert self._visit_or_pattern(pattern, expr) == True
  72. return True
  73. else:
  74. return False
  75. else:
  76. mark = len(self.matched_patterns)
  77. visiter = self.method_dict.get(type(pattern))
  78. matched = visiter(self, pattern, expr)
  79. if matched:
  80. self.matched_patterns.append(pattern)
  81. self.matched_exprs[pattern] = expr
  82. else:
  83. self.clear_map(mark)
  84. return matched
  85. @register_visiter_func(OrPattern)
  86. def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool:
  87. if self.visit_pattern(pattern.left, expr):
  88. if pattern._users:
  89. pattern.left._add_users(pattern._users[-1])
  90. return True
  91. if self.visit_pattern(pattern.right, expr):
  92. if pattern._users:
  93. pattern.right._add_users(pattern._users[-1])
  94. return True
  95. return False
  96. @register_visiter_func(CallPattern)
  97. def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool:
  98. mark = len(self.matched_patterns)
  99. match_res = self.visit_pattern(pattern.op, expr)
  100. if not match_res:
  101. self.clear_map(mark)
  102. return False
  103. inputs = expr.inputs
  104. if isinstance(pattern.op, ModulePattern):
  105. inputs = inputs[1:]
  106. if (pattern._match_all_args and len(pattern.args) != len(inputs)) or (
  107. not pattern._match_all_args and len(pattern.args) > len(inputs)
  108. ):
  109. self.clear_map(mark)
  110. return False
  111. for i, pat in enumerate(pattern.args):
  112. pat._add_users(pattern)
  113. match_res = self.visit_pattern(pat, inputs[i].expr)
  114. if not match_res:
  115. pat._clear_users()
  116. self.clear_map(mark)
  117. return False
  118. return True
  119. @register_visiter_func(ModulePattern)
  120. def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool:
  121. if not is_call_module(expr, pattern.target):
  122. return False
  123. module = expr.inputs[0].owner
  124. for key, target in pattern.attrs.items():
  125. value = getattr(module, key, None)
  126. if target != value:
  127. return False
  128. return True
  129. @register_visiter_func(FunctionPattern)
  130. def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
  131. if not is_call_function(expr, pattern.target):
  132. return False
  133. kwargs = expr.kwargs
  134. for key, target in pattern.params.items():
  135. value = kwargs.get(key, None)
  136. if target != value:
  137. return False
  138. return True
  139. @register_visiter_func(TensorMethodPattern)
  140. def _visit_tensor_method_pattern(
  141. self, pattern: TensorMethodPattern, expr: Expr
  142. ) -> bool:
  143. return is_call_tensor_method(expr, pattern.target)
  144. @register_visiter_func(ApplyDefPattern)
  145. def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool:
  146. return is_apply_def(expr, pattern.target)
  147. @register_visiter_func(ConstantPattern)
  148. def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool:
  149. return is_constant(expr)
  150. @register_visiter_func(VarPattern)
  151. def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool:
  152. return not is_constant(expr)
  153. @register_visiter_func(AnyPattern)
  154. def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool:
  155. return True