You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matcher.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from collections import OrderedDict, defaultdict
  9. from functools import partial
  10. from ...logger import get_logger
  11. from ..expr import (
  12. Expr,
  13. is_apply_def,
  14. is_call_function,
  15. is_call_module,
  16. is_call_tensor_method,
  17. is_constant,
  18. )
  19. from .pattern import (
  20. AnyPattern,
  21. ApplyDefPattern,
  22. CallPattern,
  23. ConstantPattern,
  24. ExprPattern,
  25. FunctionPattern,
  26. ModulePattern,
  27. OrPattern,
  28. TensorMethodPattern,
  29. VarPattern,
  30. )
  31. from .utils import register_obj
# Module-level logger; the matcher uses it to warn when a matched expr has
# more than one output (only single-output exprs are user-checked).
logger = get_logger(__name__)
  33. class PatternMatcher:
  34. method_dict = {}
  35. register_visiter_func = partial(register_obj, _dict=method_dict)
  36. def __init__(self) -> None:
  37. self.matched_patterns = []
  38. self.matched_exprs = OrderedDict()
  39. def match(self, pattern: ExprPattern, expr: Expr) -> bool:
  40. self.matched_exprs.clear()
  41. self.matched_patterns.clear()
  42. pattern.check_users(False)
  43. res = self.visit_pattern(pattern, expr)
  44. if res and not self._check_users():
  45. self.clear_map(0)
  46. res = False
  47. self._clear_pattern_users()
  48. return res
  49. def clear_map(self, mark):
  50. for _ in range(len(self.matched_patterns) - mark):
  51. p = self.matched_patterns.pop()
  52. self.matched_exprs.pop(p)
  53. p._clear_users()
  54. def _clear_pattern_users(self):
  55. for p in self.matched_patterns:
  56. p._clear_users()
  57. def _check_users(self) -> bool:
  58. for pat, expr in self.matched_exprs.items():
  59. if pat._check_users:
  60. pattern_users = pat._users
  61. if len(expr.outputs) != 1:
  62. logger.warning(
  63. "only support single output, and the matching "
  64. "result may be wrong"
  65. )
  66. continue
  67. expr_users = expr.outputs[0].users
  68. if len(pattern_users) != len(expr_users):
  69. return False
  70. for pat, expr in zip(pattern_users, expr_users):
  71. if self.matched_exprs[pat] != expr:
  72. return False
  73. return True
  74. def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool:
  75. if pattern in self.matched_exprs:
  76. if self.matched_exprs[pattern] is expr:
  77. if isinstance(pattern, (OrPattern)):
  78. assert self._visit_or_pattern(pattern, expr) == True
  79. return True
  80. else:
  81. return False
  82. else:
  83. mark = len(self.matched_patterns)
  84. visiter = self.method_dict.get(type(pattern))
  85. matched = visiter(self, pattern, expr)
  86. if matched:
  87. self.matched_patterns.append(pattern)
  88. self.matched_exprs[pattern] = expr
  89. else:
  90. self.clear_map(mark)
  91. return matched
  92. @register_visiter_func(OrPattern)
  93. def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool:
  94. if self.visit_pattern(pattern.left, expr):
  95. if pattern._users:
  96. pattern.left._add_users(pattern._users[-1])
  97. return True
  98. if self.visit_pattern(pattern.right, expr):
  99. if pattern._users:
  100. pattern.right._add_users(pattern._users[-1])
  101. return True
  102. return False
  103. @register_visiter_func(CallPattern)
  104. def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool:
  105. mark = len(self.matched_patterns)
  106. match_res = self.visit_pattern(pattern.op, expr)
  107. if not match_res:
  108. self.clear_map(mark)
  109. return False
  110. inputs = expr.inputs
  111. if isinstance(pattern.op, ModulePattern):
  112. inputs = inputs[1:]
  113. if (pattern._match_all_args and len(pattern.args) != len(inputs)) or (
  114. not pattern._match_all_args and len(pattern.args) > len(inputs)
  115. ):
  116. self.clear_map(mark)
  117. return False
  118. for i, pat in enumerate(pattern.args):
  119. pat._add_users(pattern)
  120. match_res = self.visit_pattern(pat, inputs[i].expr)
  121. if not match_res:
  122. pat._clear_users()
  123. self.clear_map(mark)
  124. return False
  125. return True
  126. @register_visiter_func(ModulePattern)
  127. def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool:
  128. if not is_call_module(expr, pattern.target):
  129. return False
  130. module = expr.inputs[0].owner
  131. for key, target in pattern.attrs.items():
  132. value = getattr(module, key, None)
  133. if target != value:
  134. return False
  135. return True
  136. @register_visiter_func(FunctionPattern)
  137. def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
  138. if not is_call_function(expr, pattern.target):
  139. return False
  140. kwargs = expr.kwargs
  141. for key, target in pattern.params.items():
  142. value = kwargs.get(key, None)
  143. if target != value:
  144. return False
  145. return True
  146. @register_visiter_func(TensorMethodPattern)
  147. def _visit_tensor_method_pattern(
  148. self, pattern: TensorMethodPattern, expr: Expr
  149. ) -> bool:
  150. return is_call_tensor_method(expr, pattern.target)
  151. @register_visiter_func(ApplyDefPattern)
  152. def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool:
  153. return is_apply_def(expr, pattern.target)
  154. @register_visiter_func(ConstantPattern)
  155. def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool:
  156. return is_constant(expr)
  157. @register_visiter_func(VarPattern)
  158. def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool:
  159. return not is_constant(expr)
  160. @register_visiter_func(AnyPattern)
  161. def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool:
  162. return True