diff --git a/imperative/python/megengine/traced_module/_passes/matcher.py b/imperative/python/megengine/traced_module/_passes/matcher.py new file mode 100644 index 00000000..db4765ef --- /dev/null +++ b/imperative/python/megengine/traced_module/_passes/matcher.py @@ -0,0 +1,183 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from collections import OrderedDict, defaultdict +from functools import partial + +from ...logger import get_logger +from ..expr import ( + Expr, + is_apply_def, + is_call_function, + is_call_module, + is_call_tensor_method, + is_constant, +) +from .pattern import ( + AnyPattern, + ApplyDefPattern, + CallPattern, + ConstantPattern, + ExprPattern, + FunctionPattern, + ModulePattern, + OrPattern, + TensorMethodPattern, + VarPattern, +) +from .utils import register_obj + +logger = get_logger(__name__) + + +class PatternMatcher: + + method_dict = {} + register_visiter_func = partial(register_obj, _dict=method_dict) + + def __init__(self) -> None: + self.matched_patterns = [] + self.matched_exprs = OrderedDict() + + def match(self, pattern: ExprPattern, expr: Expr) -> bool: + self.matched_exprs.clear() + self.matched_patterns.clear() + pattern.check_users(False) + res = self.visit_pattern(pattern, expr) + if res and not self._check_users(): + self.clear_map(0) + res = False + self._clear_pattern_users() + return res + + def clear_map(self, mark): + for _ in range(len(self.matched_patterns) - mark): + p = self.matched_patterns.pop() + self.matched_exprs.pop(p) + p._clear_users() + + def _clear_pattern_users(self): + for p in self.matched_patterns: + p._clear_users() + + def _check_users(self) -> bool: + for pat, expr in self.matched_exprs.items(): + if pat._check_users: + pattern_users = pat._users + if len(expr.outputs) != 1: + logger.warning( + "only support single output, and the matching " + "result may be wrong" + ) + continue + expr_users = expr.outputs[0].users + if len(pattern_users) != len(expr_users): + return False + for pat, expr in zip(pattern_users, expr_users): + if self.matched_exprs[pat] != expr: + return False + return True + + def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool: + if pattern in self.matched_exprs: + if self.matched_exprs[pattern] is expr: + if isinstance(pattern, (OrPattern)): + assert self._visit_or_pattern(pattern, expr) == True + return True + else: + return False + else: + mark = len(self.matched_patterns) + visiter = self.method_dict.get(type(pattern)) + matched = visiter(self, pattern, expr) + if matched: + self.matched_patterns.append(pattern) + self.matched_exprs[pattern] = expr + else: + self.clear_map(mark) + return matched + + @register_visiter_func(OrPattern) + def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool: + if self.visit_pattern(pattern.left, expr): + if pattern._users: + pattern.left._add_users(pattern._users[-1]) + return True + if self.visit_pattern(pattern.right, expr): + if pattern._users: + pattern.right._add_users(pattern._users[-1]) + return True + return False + + @register_visiter_func(CallPattern) + def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool: + mark = len(self.matched_patterns) + match_res = self.visit_pattern(pattern.op, expr) + if not match_res: + self.clear_map(mark) + return False + inputs = expr.inputs + if isinstance(pattern.op, ModulePattern): + inputs = inputs[1:] + if (pattern._match_all_args and len(pattern.args) != len(inputs)) or ( + not pattern._match_all_args and len(pattern.args) > len(inputs) + ): + self.clear_map(mark) + return False + for i, pat in enumerate(pattern.args): + pat._add_users(pattern) + match_res = self.visit_pattern(pat, inputs[i].expr) + if not match_res: + pat._clear_users() + self.clear_map(mark) + return False + return True + + @register_visiter_func(ModulePattern) + def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool: + if not is_call_module(expr, pattern.target): + return False + module = expr.inputs[0].owner + for key, target in pattern.attrs.items(): + value = getattr(module, key, None) + if target != value: + return False + return True + + @register_visiter_func(FunctionPattern) + def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: + if not is_call_function(expr, pattern.target): + return False + kwargs = expr.kwargs + for key, target in pattern.params.items(): + value = kwargs.get(key, None) + if target != value: + return False + return True + + @register_visiter_func(TensorMethodPattern) + def _visit_tensor_method_pattern( + self, pattern: TensorMethodPattern, expr: Expr + ) -> bool: + return is_call_tensor_method(expr, pattern.target) + + @register_visiter_func(ApplyDefPattern) + def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool: + return is_apply_def(expr, pattern.target) + + @register_visiter_func(ConstantPattern) + def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool: + return is_constant(expr) + + @register_visiter_func(VarPattern) + def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool: + return not is_constant(expr) + + @register_visiter_func(AnyPattern) + def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool: + return True diff --git a/imperative/python/megengine/traced_module/_passes/pattern.py b/imperative/python/megengine/traced_module/_passes/pattern.py new file mode 100644 index 00000000..1f559f4b --- /dev/null +++ b/imperative/python/megengine/traced_module/_passes/pattern.py @@ -0,0 +1,252 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from abc import abstractmethod +from typing import Any, Callable, Dict, List + +from ...core._imperative_rt import OpDef +from ...logger import get_logger +from ...module import Module +from ..expr import Expr +from ..node import Node + +logger = get_logger(__name__) + + +class ExprPattern: + def __init__(self): + self._check_users = True + self._users = [] + + def __call__(self, *args): + args = list(args) + if len(args) == 1 and args[0] is None: + args = None + return CallPattern(self, *args) + + def __add__(self, other): + return is_op("__add__")(self, other) + + def __iadd__(self, other): + return is_op("__iadd__")(self, other) + + def __radd__(self, other): + return is_op("__radd__")(self, other) + + def __sub__(self, other): + return is_op("__sub__")(self, other) + + def __isub__(self, other): + return is_op("__isub__")(self, other) + + def __rsub__(self, other): + return is_op("__rsub__")(self, other) + + def __mul__(self, other): + return is_op("__mul__")(self, other) + + def __imul__(self, other): + return is_op("__imul__")(self, other) + + def __rmul__(self, other): + return is_op("__rmul__")(self, other) + + def __truediv__(self, other): + return is_op("__truediv__")(self, other) + + def __itruediv__(self, other): + return is_op("__itruediv__")(self, other) + + def __rtruediv__(self, other): + return is_op("__rtruediv__")(self, other) + + def __or__(self, other): + assert isinstance(other, ExprPattern) + return OrPattern(self, other) + + def get_output(self, index): + raise NotImplementedError + + def check_users(self, check: bool = True): + self._check_users = check + return self + + def _add_users(self, pattern: "ExprPattern"): + self._users.append(pattern) + + def _clear_users(self,): + self._users.clear() + + def __getitem__(self, index): + return is_op("__getitem__")(self, index) + + def has_attr(self, **attrs): + logger.warning("has_param only support ModulePattern") + return self + + def has_param(self, **params): + logger.warning("has_param only support FunctionPattern") + return self + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError + + +class CallPattern(ExprPattern): + def __init__(self, op: ExprPattern, *args: List[ExprPattern]): + super().__init__() + self.op = op + self.args = list(filter(lambda x: isinstance(x, ExprPattern), args)) + self._match_all_args = True + + def __repr__(self) -> str: + return "{}({})".format(self.op, ",".join(str(x) for x in self.args)) + + def not_all_args(self): + self._match_all_args = False + + def check_users(self, check: bool = True): + self._check_users = check + self.op.check_users(check) + return self + + def _add_users(self, pattern: "ExprPattern"): + self._users.append(pattern) + self.op._add_users(pattern) + + def _clear_users(self): + self._users.clear() + self.op._clear_users() + + +class OrPattern(ExprPattern): + def __init__(self, left: ExprPattern, right: ExprPattern): + super().__init__() + self.left = left + self.right = right + + def __repr__(self) -> str: + return "({}|{})".format(self.left, self.right) + + def check_users(self, check: bool = True): + self._check_users = check + self.left.check_users(check) + self.right.check_users(check) + return self + + def _clear_users(self): + self._users.clear() + self.left._clear_users() + self.right._clear_users() + + +class GetOutputPaterrn(ExprPattern): + def __init__(self, op, index): + super().__init__() + self.op = op + self.index = index + + def __repr__(self) -> str: + return "{}[{}]".format(self.op, self.index) + + +class ModulePattern(ExprPattern): + def __init__(self, module_cls: Module) -> None: + super().__init__() + self.attrs = {} + self.target = module_cls + + def has_attr(self, **attrs): + self.attrs.update(attrs) + return self + + def __repr__(self) -> str: + return "{}".format(self.target.__name__) + + +class FunctionPattern(ExprPattern): + def __init__(self, func: Callable): + super().__init__() + self.params = {} + self.target = func + + def has_params(self, **params): + self.params.update(params) + return self + + def __repr__(self) -> str: + return "{}".format(self.target.__name__) + + +class TensorMethodPattern(ExprPattern): + def __init__(self, method: str): + super().__init__() + self.target = method + + def __repr__(self) -> str: + return self.target + + +class ApplyDefPattern(ExprPattern): + def __init__(self, opdef: OpDef): + super().__init__() + self.target = opdef + + def __repr__(self) -> str: + return "{}".format(self.target.__name__) + + +class VarPattern(ExprPattern): + def __init__(self): + super().__init__() + + def __repr__(self) -> str: + return "var" + + +class ConstantPattern(ExprPattern): + def __init__(self): + super().__init__() + + def __repr__(self) -> str: + return "const" + + +class AnyPattern(ExprPattern): + def __init__(self): + super().__init__() + + def __repr__(self) -> str: + return "any" + + +def is_op(target): + if isinstance(target, type): + if issubclass(target, Module): + return ModulePattern(target) + if issubclass(target, OpDef): + return ApplyDefPattern(target) + elif callable(target): + return FunctionPattern(target) + elif isinstance(target, str): + return TensorMethodPattern(target) + else: + raise ValueError("not support") + + +def is_const(): + return ConstantPattern().check_users(False) + + +def any_node(): + return AnyPattern() + + +def is_var(): + return VarPattern() diff --git a/imperative/python/megengine/traced_module/_passes/utils.py b/imperative/python/megengine/traced_module/_passes/utils.py new file mode 100644 index 00000000..ed25072e --- /dev/null +++ b/imperative/python/megengine/traced_module/_passes/utils.py @@ -0,0 +1,38 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import copy +from typing import Any, Dict, List + +from ..expr import Expr, is_constant, is_getattr +from ..node import Node, TensorNode + + +def register_obj(objs: List[Any], _dict: Dict): + if not isinstance(objs, List): + objs = [objs] + + def _register(any_obj: Any): + for obj in objs: + _dict[obj] = any_obj + return any_obj + + return _register + + +def get_const_value(expr: Expr, fall_back: Any = None): + value = fall_back + if isinstance(expr, Node): + expr = expr.expr + if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode): + module = expr.inputs[0].owner + assert module is not None + value = copy.deepcopy(expr.interpret(module)[0]) + elif is_constant(expr): + value = copy.deepcopy(expr.interpret()[0]) + return value