GitOrigin-RevId: 0af7b076e6
release-1.7
@@ -0,0 +1,183 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from collections import OrderedDict, defaultdict | |||
from functools import partial | |||
from ...logger import get_logger | |||
from ..expr import ( | |||
Expr, | |||
is_apply_def, | |||
is_call_function, | |||
is_call_module, | |||
is_call_tensor_method, | |||
is_constant, | |||
) | |||
from .pattern import ( | |||
AnyPattern, | |||
ApplyDefPattern, | |||
CallPattern, | |||
ConstantPattern, | |||
ExprPattern, | |||
FunctionPattern, | |||
ModulePattern, | |||
OrPattern, | |||
TensorMethodPattern, | |||
VarPattern, | |||
) | |||
from .utils import register_obj | |||
logger = get_logger(__name__) | |||
class PatternMatcher: | |||
method_dict = {} | |||
register_visiter_func = partial(register_obj, _dict=method_dict) | |||
def __init__(self) -> None: | |||
self.matched_patterns = [] | |||
self.matched_exprs = OrderedDict() | |||
def match(self, pattern: ExprPattern, expr: Expr) -> bool: | |||
self.matched_exprs.clear() | |||
self.matched_patterns.clear() | |||
pattern.check_users(False) | |||
res = self.visit_pattern(pattern, expr) | |||
if res and not self._check_users(): | |||
self.clear_map(0) | |||
res = False | |||
self._clear_pattern_users() | |||
return res | |||
def clear_map(self, mark): | |||
for _ in range(len(self.matched_patterns) - mark): | |||
p = self.matched_patterns.pop() | |||
self.matched_exprs.pop(p) | |||
p._clear_users() | |||
def _clear_pattern_users(self): | |||
for p in self.matched_patterns: | |||
p._clear_users() | |||
def _check_users(self) -> bool: | |||
for pat, expr in self.matched_exprs.items(): | |||
if pat._check_users: | |||
pattern_users = pat._users | |||
if len(expr.outputs) != 1: | |||
logger.warning( | |||
"only support single output, and the matching " | |||
"result may be wrong" | |||
) | |||
continue | |||
expr_users = expr.outputs[0].users | |||
if len(pattern_users) != len(expr_users): | |||
return False | |||
for pat, expr in zip(pattern_users, expr_users): | |||
if self.matched_exprs[pat] != expr: | |||
return False | |||
return True | |||
def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool: | |||
if pattern in self.matched_exprs: | |||
if self.matched_exprs[pattern] is expr: | |||
if isinstance(pattern, (OrPattern)): | |||
assert self._visit_or_pattern(pattern, expr) == True | |||
return True | |||
else: | |||
return False | |||
else: | |||
mark = len(self.matched_patterns) | |||
visiter = self.method_dict.get(type(pattern)) | |||
matched = visiter(self, pattern, expr) | |||
if matched: | |||
self.matched_patterns.append(pattern) | |||
self.matched_exprs[pattern] = expr | |||
else: | |||
self.clear_map(mark) | |||
return matched | |||
@register_visiter_func(OrPattern) | |||
def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool: | |||
if self.visit_pattern(pattern.left, expr): | |||
if pattern._users: | |||
pattern.left._add_users(pattern._users[-1]) | |||
return True | |||
if self.visit_pattern(pattern.right, expr): | |||
if pattern._users: | |||
pattern.right._add_users(pattern._users[-1]) | |||
return True | |||
return False | |||
@register_visiter_func(CallPattern) | |||
def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool: | |||
mark = len(self.matched_patterns) | |||
match_res = self.visit_pattern(pattern.op, expr) | |||
if not match_res: | |||
self.clear_map(mark) | |||
return False | |||
inputs = expr.inputs | |||
if isinstance(pattern.op, ModulePattern): | |||
inputs = inputs[1:] | |||
if (pattern._match_all_args and len(pattern.args) != len(inputs)) or ( | |||
not pattern._match_all_args and len(pattern.args) > len(inputs) | |||
): | |||
self.clear_map(mark) | |||
return False | |||
for i, pat in enumerate(pattern.args): | |||
pat._add_users(pattern) | |||
match_res = self.visit_pattern(pat, inputs[i].expr) | |||
if not match_res: | |||
pat._clear_users() | |||
self.clear_map(mark) | |||
return False | |||
return True | |||
@register_visiter_func(ModulePattern) | |||
def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool: | |||
if not is_call_module(expr, pattern.target): | |||
return False | |||
module = expr.inputs[0].owner | |||
for key, target in pattern.attrs.items(): | |||
value = getattr(module, key, None) | |||
if target != value: | |||
return False | |||
return True | |||
@register_visiter_func(FunctionPattern) | |||
def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: | |||
if not is_call_function(expr, pattern.target): | |||
return False | |||
kwargs = expr.kwargs | |||
for key, target in pattern.params.items(): | |||
value = kwargs.get(key, None) | |||
if target != value: | |||
return False | |||
return True | |||
@register_visiter_func(TensorMethodPattern) | |||
def _visit_tensor_method_pattern( | |||
self, pattern: TensorMethodPattern, expr: Expr | |||
) -> bool: | |||
return is_call_tensor_method(expr, pattern.target) | |||
@register_visiter_func(ApplyDefPattern) | |||
def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool: | |||
return is_apply_def(expr, pattern.target) | |||
@register_visiter_func(ConstantPattern) | |||
def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool: | |||
return is_constant(expr) | |||
@register_visiter_func(VarPattern) | |||
def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool: | |||
return not is_constant(expr) | |||
@register_visiter_func(AnyPattern) | |||
def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool: | |||
return True |
@@ -0,0 +1,252 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from typing import Any, Callable, Dict, List | |||
from ...core._imperative_rt import OpDef | |||
from ...logger import get_logger | |||
from ...module import Module | |||
from ..expr import Expr | |||
from ..node import Node | |||
logger = get_logger(__name__) | |||
class ExprPattern: | |||
def __init__(self): | |||
self._check_users = True | |||
self._users = [] | |||
def __call__(self, *args): | |||
args = list(args) | |||
if len(args) == 1 and args[0] is None: | |||
args = None | |||
return CallPattern(self, *args) | |||
def __add__(self, other): | |||
return is_op("__add__")(self, other) | |||
def __iadd__(self, other): | |||
return is_op("__iadd__")(self, other) | |||
def __radd__(self, other): | |||
return is_op("__radd__")(self, other) | |||
def __sub__(self, other): | |||
return is_op("__sub__")(self, other) | |||
def __isub__(self, other): | |||
return is_op("__isub__")(self, other) | |||
def __rsub__(self, other): | |||
return is_op("__rsub__")(self, other) | |||
def __mul__(self, other): | |||
return is_op("__mul__")(self, other) | |||
def __imul__(self, other): | |||
return is_op("__imul__")(self, other) | |||
def __rmul__(self, other): | |||
return is_op("__rmul__")(self, other) | |||
def __truediv__(self, other): | |||
return is_op("__truediv__")(self, other) | |||
def __itruediv__(self, other): | |||
return is_op("__itruediv__")(self, other) | |||
def __rtruediv__(self, other): | |||
return is_op("__rtruediv__")(self, other) | |||
def __or__(self, other): | |||
assert isinstance(other, ExprPattern) | |||
return OrPattern(self, other) | |||
def get_output(self, index): | |||
raise NotImplementedError | |||
def check_users(self, check: bool = True): | |||
self._check_users = check | |||
return self | |||
def _add_users(self, pattern: "ExprPattern"): | |||
self._users.append(pattern) | |||
def _clear_users(self,): | |||
self._users.clear() | |||
def __getitem__(self, index): | |||
return is_op("__getitem__")(self, index) | |||
def has_attr(self, **attrs): | |||
logger.warning("has_param only support ModulePattern") | |||
return self | |||
def has_param(self, **params): | |||
logger.warning("has_param only support FunctionPattern") | |||
return self | |||
@abstractmethod | |||
def __repr__(self) -> str: | |||
raise NotImplementedError | |||
class CallPattern(ExprPattern): | |||
def __init__(self, op: ExprPattern, *args: List[ExprPattern]): | |||
super().__init__() | |||
self.op = op | |||
self.args = list(filter(lambda x: isinstance(x, ExprPattern), args)) | |||
self._match_all_args = True | |||
def __repr__(self) -> str: | |||
return "{}({})".format(self.op, ",".join(str(x) for x in self.args)) | |||
def not_all_args(self): | |||
self._match_all_args = False | |||
def check_users(self, check: bool = True): | |||
self._check_users = check | |||
self.op.check_users(check) | |||
return self | |||
def _add_users(self, pattern: "ExprPattern"): | |||
self._users.append(pattern) | |||
self.op._add_users(pattern) | |||
def _clear_users(self): | |||
self._users.clear() | |||
self.op._clear_users() | |||
class OrPattern(ExprPattern): | |||
def __init__(self, left: ExprPattern, right: ExprPattern): | |||
super().__init__() | |||
self.left = left | |||
self.right = right | |||
def __repr__(self) -> str: | |||
return "({}|{})".format(self.left, self.right) | |||
def check_users(self, check: bool = True): | |||
self._check_users = check | |||
self.left.check_users(check) | |||
self.right.check_users(check) | |||
return self | |||
def _clear_users(self): | |||
self._users.clear() | |||
self.left._clear_users() | |||
self.right._clear_users() | |||
class GetOutputPaterrn(ExprPattern): | |||
def __init__(self, op, index): | |||
super().__init__() | |||
self.op = op | |||
self.index = index | |||
def __repr__(self) -> str: | |||
return "{}[{}]".format(self.op, self.index) | |||
class ModulePattern(ExprPattern): | |||
def __init__(self, module_cls: Module) -> None: | |||
super().__init__() | |||
self.attrs = {} | |||
self.target = module_cls | |||
def has_attr(self, **attrs): | |||
self.attrs.update(attrs) | |||
return self | |||
def __repr__(self) -> str: | |||
return "{}".format(self.target.__name__) | |||
class FunctionPattern(ExprPattern): | |||
def __init__(self, func: Callable): | |||
super().__init__() | |||
self.params = {} | |||
self.target = func | |||
def has_params(self, **params): | |||
self.params.update(params) | |||
return self | |||
def __repr__(self) -> str: | |||
return "{}".format(self.target.__name__) | |||
class TensorMethodPattern(ExprPattern): | |||
def __init__(self, method: str): | |||
super().__init__() | |||
self.target = method | |||
def __repr__(self) -> str: | |||
return self.target | |||
class ApplyDefPattern(ExprPattern): | |||
def __init__(self, opdef: OpDef): | |||
super().__init__() | |||
self.target = opdef | |||
def __repr__(self) -> str: | |||
return "{}".format(self.target.__name__) | |||
class VarPattern(ExprPattern): | |||
def __init__(self): | |||
super().__init__() | |||
def __repr__(self) -> str: | |||
return "var" | |||
class ConstantPattern(ExprPattern): | |||
def __init__(self): | |||
super().__init__() | |||
def __repr__(self) -> str: | |||
return "const" | |||
class AnyPattern(ExprPattern): | |||
def __init__(self): | |||
super().__init__() | |||
def __repr__(self) -> str: | |||
return "any" | |||
def is_op(target): | |||
if isinstance(target, type): | |||
if issubclass(target, Module): | |||
return ModulePattern(target) | |||
if issubclass(target, OpDef): | |||
return ApplyDefPattern(target) | |||
elif callable(target): | |||
return FunctionPattern(target) | |||
elif isinstance(target, str): | |||
return TensorMethodPattern(target) | |||
else: | |||
raise ValueError("not support") | |||
def is_const(): | |||
return ConstantPattern().check_users(False) | |||
def any_node(): | |||
return AnyPattern() | |||
def is_var(): | |||
return VarPattern() |
@@ -0,0 +1,38 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import copy | |||
from typing import Any, Dict, List | |||
from ..expr import Expr, is_constant, is_getattr | |||
from ..node import Node, TensorNode | |||
def register_obj(objs: List[Any], _dict: Dict): | |||
if not isinstance(objs, List): | |||
objs = [objs] | |||
def _register(any_obj: Any): | |||
for obj in objs: | |||
_dict[obj] = any_obj | |||
return any_obj | |||
return _register | |||
def get_const_value(expr: Expr, fall_back: Any = None): | |||
value = fall_back | |||
if isinstance(expr, Node): | |||
expr = expr.expr | |||
if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode): | |||
module = expr.inputs[0].owner | |||
assert module is not None | |||
value = copy.deepcopy(expr.interpret(module)[0]) | |||
elif is_constant(expr): | |||
value = copy.deepcopy(expr.interpret()[0]) | |||
return value |