|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- from collections import OrderedDict, defaultdict
- from functools import partial
-
- from ...logger import get_logger
- from ..expr import (
- Expr,
- is_apply_def,
- is_call_function,
- is_call_module,
- is_call_tensor_method,
- is_constant,
- )
- from .pattern import (
- AnyPattern,
- ApplyDefPattern,
- CallPattern,
- ConstantPattern,
- ExprPattern,
- FunctionPattern,
- ModulePattern,
- OrPattern,
- TensorMethodPattern,
- VarPattern,
- )
- from .utils import register_obj
-
# Module logger; PatternMatcher._check_users uses it to warn when a matched
# expr has more than one output and its users cannot be verified.
logger = get_logger(__name__)
-
-
class PatternMatcher:
    """Match an ``ExprPattern`` tree against an expression graph.

    Visitors are registered per pattern class in ``method_dict`` through
    ``register_visiter_func``. While matching, ``matched_exprs`` maps every
    successfully matched pattern to the expr it matched. ``matched_patterns``
    lists the same patterns in match order, so that a failed branch can be
    rolled back with ``clear_map``.
    """

    # Pattern class -> visitor function(self, pattern, expr) -> bool.
    method_dict = {}
    register_visiter_func = partial(register_obj, _dict=method_dict)

    def __init__(self) -> None:
        self.matched_patterns = []
        self.matched_exprs = OrderedDict()

    def match(self, pattern: ExprPattern, expr: Expr) -> bool:
        """Return whether ``pattern`` matches the graph rooted at ``expr``.

        On success, ``matched_exprs`` holds the pattern -> expr mapping of the
        match. On failure it is left empty. The user lists of the matched
        patterns are cleared before returning.
        """
        self.matched_exprs.clear()
        self.matched_patterns.clear()
        # NOTE(review): presumably disables the user check for the root
        # pattern, whose output may be consumed outside the match -- confirm
        # in pattern.py.
        pattern.check_users(False)
        res = self.visit_pattern(pattern, expr)
        if res and not self._check_users():
            self.clear_map(0)
            res = False
        self._clear_pattern_users()
        return res

    def clear_map(self, mark):
        """Roll back every match recorded after position ``mark``.

        Patterns are popped newest-first. Each one is removed from
        ``matched_exprs`` and has its recorded users cleared.
        """
        while len(self.matched_patterns) > mark:
            p = self.matched_patterns.pop()
            self.matched_exprs.pop(p)
            p._clear_users()

    def _clear_pattern_users(self):
        """Clear the user lists of every currently matched pattern."""
        for p in self.matched_patterns:
            p._clear_users()

    def _check_users(self) -> bool:
        """Check that matched outputs are only consumed by matched exprs.

        For each matched pattern whose ``_check_users`` flag is set, the
        users of its expr's single output must be exactly the exprs matched
        by the pattern's recorded users, position by position. Exprs with
        more than one output are skipped with a warning.
        """
        for pat, expr in self.matched_exprs.items():
            if not pat._check_users:
                continue
            if len(expr.outputs) != 1:
                logger.warning(
                    "only support single output, and the matching "
                    "result may be wrong"
                )
                continue
            pattern_users = pat._users
            expr_users = expr.outputs[0].users
            if len(pattern_users) != len(expr_users):
                return False
            for user_pat, user_expr in zip(pattern_users, expr_users):
                # A recorded user that is not part of the current match cannot
                # account for this consumer. ``get`` avoids a KeyError for it.
                if self.matched_exprs.get(user_pat) is not user_expr:
                    return False
        return True

    def _get_visiter(self, pattern: ExprPattern):
        """Return the visitor for ``pattern``'s class or its nearest base.

        Raises:
            TypeError: if no class in the pattern's MRO has a registered
                visitor.
        """
        for cls in type(pattern).__mro__:
            visiter = self.method_dict.get(cls)
            if visiter is not None:
                return visiter
        raise TypeError(
            "no visitor registered for pattern type {}".format(
                type(pattern).__name__
            )
        )

    def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool:
        """Match ``pattern`` against ``expr`` and record the outcome.

        A pattern that is already matched only matches the same expr (by
        identity) again. Otherwise the registered visitor runs. On success
        the pair is recorded; on failure everything recorded during the
        attempt is rolled back.
        """
        if pattern in self.matched_exprs:
            if self.matched_exprs[pattern] is not expr:
                return False
            if isinstance(pattern, OrPattern):
                # Re-visit so the newest user is propagated into the matching
                # branch. The call must not live inside ``assert``: under
                # ``python -O`` it would be skipped entirely.
                rematched = self._visit_or_pattern(pattern, expr)
                assert rematched, "OrPattern failed to re-match its recorded expr"
            return True
        mark = len(self.matched_patterns)
        visiter = self._get_visiter(pattern)
        matched = visiter(self, pattern, expr)
        if matched:
            self.matched_patterns.append(pattern)
            self.matched_exprs[pattern] = expr
        else:
            self.clear_map(mark)
        return matched

    @register_visiter_func(OrPattern)
    def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool:
        """Match the left branch, falling back to the right one.

        The most recently recorded user of the OrPattern is forwarded to
        whichever branch matched.
        """
        for branch in (pattern.left, pattern.right):
            if self.visit_pattern(branch, expr):
                if pattern._users:
                    branch._add_users(pattern._users[-1])
                return True
        return False

    @register_visiter_func(CallPattern)
    def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool:
        """Match the callee via ``pattern.op``, then each argument pattern.

        Each argument pattern is matched against the expr that produced the
        corresponding input. With ``_match_all_args`` the argument count must
        equal the input count; otherwise it must not exceed it.
        """
        mark = len(self.matched_patterns)
        if not self.visit_pattern(pattern.op, expr):
            self.clear_map(mark)
            return False
        inputs = expr.inputs
        if isinstance(pattern.op, ModulePattern):
            # The first input of a module call is the module itself
            # (see _visit_module_pattern), not an argument.
            inputs = inputs[1:]
        n_args = len(pattern.args)
        if pattern._match_all_args:
            arity_ok = n_args == len(inputs)
        else:
            arity_ok = n_args <= len(inputs)
        if not arity_ok:
            self.clear_map(mark)
            return False
        # zip is safe here: the arity check guarantees len(args) <= len(inputs).
        for pat, node in zip(pattern.args, inputs):
            # Record this call as a user of the argument pattern before visiting,
            # so that _check_users can later compare it with the graph.
            pat._add_users(pattern)
            if not self.visit_pattern(pat, node.expr):
                pat._clear_users()
                self.clear_map(mark)
                return False
        return True

    @register_visiter_func(ModulePattern)
    def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool:
        """Match a call of ``pattern.target`` whose module has every ``attrs`` value."""
        if not is_call_module(expr, pattern.target):
            return False
        module = expr.inputs[0].owner
        for key, target in pattern.attrs.items():
            value = getattr(module, key, None)
            if target != value:
                return False
        return True

    @register_visiter_func(FunctionPattern)
    def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
        """Match a call of ``pattern.target`` whose kwargs equal every ``params`` value.

        Only keyword arguments are consulted; a missing kwarg compares as ``None``.
        """
        if not is_call_function(expr, pattern.target):
            return False
        kwargs = expr.kwargs
        for key, target in pattern.params.items():
            value = kwargs.get(key, None)
            if target != value:
                return False
        return True

    @register_visiter_func(TensorMethodPattern)
    def _visit_tensor_method_pattern(
        self, pattern: TensorMethodPattern, expr: Expr
    ) -> bool:
        """Match a call of the tensor method ``pattern.target``."""
        return is_call_tensor_method(expr, pattern.target)

    @register_visiter_func(ApplyDefPattern)
    def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool:
        """Match an apply expr of ``pattern.target``."""
        return is_apply_def(expr, pattern.target)

    @register_visiter_func(ConstantPattern)
    def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool:
        """Match any constant expr."""
        return is_constant(expr)

    @register_visiter_func(VarPattern)
    def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool:
        """Match any non-constant expr."""
        return not is_constant(expr)

    @register_visiter_func(AnyPattern)
    def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool:
        """Match every expr unconditionally."""
        return True
|