GitOrigin-RevId: 0af7b076e6
release-1.7
@@ -0,0 +1,183 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from collections import OrderedDict, defaultdict | |||||
from functools import partial | |||||
from ...logger import get_logger | |||||
from ..expr import ( | |||||
Expr, | |||||
is_apply_def, | |||||
is_call_function, | |||||
is_call_module, | |||||
is_call_tensor_method, | |||||
is_constant, | |||||
) | |||||
from .pattern import ( | |||||
AnyPattern, | |||||
ApplyDefPattern, | |||||
CallPattern, | |||||
ConstantPattern, | |||||
ExprPattern, | |||||
FunctionPattern, | |||||
ModulePattern, | |||||
OrPattern, | |||||
TensorMethodPattern, | |||||
VarPattern, | |||||
) | |||||
from .utils import register_obj | |||||
logger = get_logger(__name__) | |||||
class PatternMatcher: | |||||
method_dict = {} | |||||
register_visiter_func = partial(register_obj, _dict=method_dict) | |||||
def __init__(self) -> None: | |||||
self.matched_patterns = [] | |||||
self.matched_exprs = OrderedDict() | |||||
def match(self, pattern: ExprPattern, expr: Expr) -> bool: | |||||
self.matched_exprs.clear() | |||||
self.matched_patterns.clear() | |||||
pattern.check_users(False) | |||||
res = self.visit_pattern(pattern, expr) | |||||
if res and not self._check_users(): | |||||
self.clear_map(0) | |||||
res = False | |||||
self._clear_pattern_users() | |||||
return res | |||||
def clear_map(self, mark): | |||||
for _ in range(len(self.matched_patterns) - mark): | |||||
p = self.matched_patterns.pop() | |||||
self.matched_exprs.pop(p) | |||||
p._clear_users() | |||||
def _clear_pattern_users(self): | |||||
for p in self.matched_patterns: | |||||
p._clear_users() | |||||
def _check_users(self) -> bool: | |||||
for pat, expr in self.matched_exprs.items(): | |||||
if pat._check_users: | |||||
pattern_users = pat._users | |||||
if len(expr.outputs) != 1: | |||||
logger.warning( | |||||
"only support single output, and the matching " | |||||
"result may be wrong" | |||||
) | |||||
continue | |||||
expr_users = expr.outputs[0].users | |||||
if len(pattern_users) != len(expr_users): | |||||
return False | |||||
for pat, expr in zip(pattern_users, expr_users): | |||||
if self.matched_exprs[pat] != expr: | |||||
return False | |||||
return True | |||||
def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool: | |||||
if pattern in self.matched_exprs: | |||||
if self.matched_exprs[pattern] is expr: | |||||
if isinstance(pattern, (OrPattern)): | |||||
assert self._visit_or_pattern(pattern, expr) == True | |||||
return True | |||||
else: | |||||
return False | |||||
else: | |||||
mark = len(self.matched_patterns) | |||||
visiter = self.method_dict.get(type(pattern)) | |||||
matched = visiter(self, pattern, expr) | |||||
if matched: | |||||
self.matched_patterns.append(pattern) | |||||
self.matched_exprs[pattern] = expr | |||||
else: | |||||
self.clear_map(mark) | |||||
return matched | |||||
@register_visiter_func(OrPattern) | |||||
def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool: | |||||
if self.visit_pattern(pattern.left, expr): | |||||
if pattern._users: | |||||
pattern.left._add_users(pattern._users[-1]) | |||||
return True | |||||
if self.visit_pattern(pattern.right, expr): | |||||
if pattern._users: | |||||
pattern.right._add_users(pattern._users[-1]) | |||||
return True | |||||
return False | |||||
@register_visiter_func(CallPattern) | |||||
def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool: | |||||
mark = len(self.matched_patterns) | |||||
match_res = self.visit_pattern(pattern.op, expr) | |||||
if not match_res: | |||||
self.clear_map(mark) | |||||
return False | |||||
inputs = expr.inputs | |||||
if isinstance(pattern.op, ModulePattern): | |||||
inputs = inputs[1:] | |||||
if (pattern._match_all_args and len(pattern.args) != len(inputs)) or ( | |||||
not pattern._match_all_args and len(pattern.args) > len(inputs) | |||||
): | |||||
self.clear_map(mark) | |||||
return False | |||||
for i, pat in enumerate(pattern.args): | |||||
pat._add_users(pattern) | |||||
match_res = self.visit_pattern(pat, inputs[i].expr) | |||||
if not match_res: | |||||
pat._clear_users() | |||||
self.clear_map(mark) | |||||
return False | |||||
return True | |||||
@register_visiter_func(ModulePattern) | |||||
def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool: | |||||
if not is_call_module(expr, pattern.target): | |||||
return False | |||||
module = expr.inputs[0].owner | |||||
for key, target in pattern.attrs.items(): | |||||
value = getattr(module, key, None) | |||||
if target != value: | |||||
return False | |||||
return True | |||||
@register_visiter_func(FunctionPattern) | |||||
def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: | |||||
if not is_call_function(expr, pattern.target): | |||||
return False | |||||
kwargs = expr.kwargs | |||||
for key, target in pattern.params.items(): | |||||
value = kwargs.get(key, None) | |||||
if target != value: | |||||
return False | |||||
return True | |||||
@register_visiter_func(TensorMethodPattern) | |||||
def _visit_tensor_method_pattern( | |||||
self, pattern: TensorMethodPattern, expr: Expr | |||||
) -> bool: | |||||
return is_call_tensor_method(expr, pattern.target) | |||||
@register_visiter_func(ApplyDefPattern) | |||||
def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool: | |||||
return is_apply_def(expr, pattern.target) | |||||
@register_visiter_func(ConstantPattern) | |||||
def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool: | |||||
return is_constant(expr) | |||||
@register_visiter_func(VarPattern) | |||||
def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool: | |||||
return not is_constant(expr) | |||||
@register_visiter_func(AnyPattern) | |||||
def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool: | |||||
return True |
@@ -0,0 +1,252 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from abc import abstractmethod | |||||
from typing import Any, Callable, Dict, List | |||||
from ...core._imperative_rt import OpDef | |||||
from ...logger import get_logger | |||||
from ...module import Module | |||||
from ..expr import Expr | |||||
from ..node import Node | |||||
logger = get_logger(__name__) | |||||
class ExprPattern: | |||||
def __init__(self): | |||||
self._check_users = True | |||||
self._users = [] | |||||
def __call__(self, *args): | |||||
args = list(args) | |||||
if len(args) == 1 and args[0] is None: | |||||
args = None | |||||
return CallPattern(self, *args) | |||||
def __add__(self, other): | |||||
return is_op("__add__")(self, other) | |||||
def __iadd__(self, other): | |||||
return is_op("__iadd__")(self, other) | |||||
def __radd__(self, other): | |||||
return is_op("__radd__")(self, other) | |||||
def __sub__(self, other): | |||||
return is_op("__sub__")(self, other) | |||||
def __isub__(self, other): | |||||
return is_op("__isub__")(self, other) | |||||
def __rsub__(self, other): | |||||
return is_op("__rsub__")(self, other) | |||||
def __mul__(self, other): | |||||
return is_op("__mul__")(self, other) | |||||
def __imul__(self, other): | |||||
return is_op("__imul__")(self, other) | |||||
def __rmul__(self, other): | |||||
return is_op("__rmul__")(self, other) | |||||
def __truediv__(self, other): | |||||
return is_op("__truediv__")(self, other) | |||||
def __itruediv__(self, other): | |||||
return is_op("__itruediv__")(self, other) | |||||
def __rtruediv__(self, other): | |||||
return is_op("__rtruediv__")(self, other) | |||||
def __or__(self, other): | |||||
assert isinstance(other, ExprPattern) | |||||
return OrPattern(self, other) | |||||
def get_output(self, index): | |||||
raise NotImplementedError | |||||
def check_users(self, check: bool = True): | |||||
self._check_users = check | |||||
return self | |||||
def _add_users(self, pattern: "ExprPattern"): | |||||
self._users.append(pattern) | |||||
def _clear_users(self,): | |||||
self._users.clear() | |||||
def __getitem__(self, index): | |||||
return is_op("__getitem__")(self, index) | |||||
def has_attr(self, **attrs): | |||||
logger.warning("has_param only support ModulePattern") | |||||
return self | |||||
def has_param(self, **params): | |||||
logger.warning("has_param only support FunctionPattern") | |||||
return self | |||||
@abstractmethod | |||||
def __repr__(self) -> str: | |||||
raise NotImplementedError | |||||
class CallPattern(ExprPattern): | |||||
def __init__(self, op: ExprPattern, *args: List[ExprPattern]): | |||||
super().__init__() | |||||
self.op = op | |||||
self.args = list(filter(lambda x: isinstance(x, ExprPattern), args)) | |||||
self._match_all_args = True | |||||
def __repr__(self) -> str: | |||||
return "{}({})".format(self.op, ",".join(str(x) for x in self.args)) | |||||
def not_all_args(self): | |||||
self._match_all_args = False | |||||
def check_users(self, check: bool = True): | |||||
self._check_users = check | |||||
self.op.check_users(check) | |||||
return self | |||||
def _add_users(self, pattern: "ExprPattern"): | |||||
self._users.append(pattern) | |||||
self.op._add_users(pattern) | |||||
def _clear_users(self): | |||||
self._users.clear() | |||||
self.op._clear_users() | |||||
class OrPattern(ExprPattern): | |||||
def __init__(self, left: ExprPattern, right: ExprPattern): | |||||
super().__init__() | |||||
self.left = left | |||||
self.right = right | |||||
def __repr__(self) -> str: | |||||
return "({}|{})".format(self.left, self.right) | |||||
def check_users(self, check: bool = True): | |||||
self._check_users = check | |||||
self.left.check_users(check) | |||||
self.right.check_users(check) | |||||
return self | |||||
def _clear_users(self): | |||||
self._users.clear() | |||||
self.left._clear_users() | |||||
self.right._clear_users() | |||||
class GetOutputPaterrn(ExprPattern): | |||||
def __init__(self, op, index): | |||||
super().__init__() | |||||
self.op = op | |||||
self.index = index | |||||
def __repr__(self) -> str: | |||||
return "{}[{}]".format(self.op, self.index) | |||||
class ModulePattern(ExprPattern): | |||||
def __init__(self, module_cls: Module) -> None: | |||||
super().__init__() | |||||
self.attrs = {} | |||||
self.target = module_cls | |||||
def has_attr(self, **attrs): | |||||
self.attrs.update(attrs) | |||||
return self | |||||
def __repr__(self) -> str: | |||||
return "{}".format(self.target.__name__) | |||||
class FunctionPattern(ExprPattern): | |||||
def __init__(self, func: Callable): | |||||
super().__init__() | |||||
self.params = {} | |||||
self.target = func | |||||
def has_params(self, **params): | |||||
self.params.update(params) | |||||
return self | |||||
def __repr__(self) -> str: | |||||
return "{}".format(self.target.__name__) | |||||
class TensorMethodPattern(ExprPattern): | |||||
def __init__(self, method: str): | |||||
super().__init__() | |||||
self.target = method | |||||
def __repr__(self) -> str: | |||||
return self.target | |||||
class ApplyDefPattern(ExprPattern): | |||||
def __init__(self, opdef: OpDef): | |||||
super().__init__() | |||||
self.target = opdef | |||||
def __repr__(self) -> str: | |||||
return "{}".format(self.target.__name__) | |||||
class VarPattern(ExprPattern): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def __repr__(self) -> str: | |||||
return "var" | |||||
class ConstantPattern(ExprPattern): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def __repr__(self) -> str: | |||||
return "const" | |||||
class AnyPattern(ExprPattern): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def __repr__(self) -> str: | |||||
return "any" | |||||
def is_op(target): | |||||
if isinstance(target, type): | |||||
if issubclass(target, Module): | |||||
return ModulePattern(target) | |||||
if issubclass(target, OpDef): | |||||
return ApplyDefPattern(target) | |||||
elif callable(target): | |||||
return FunctionPattern(target) | |||||
elif isinstance(target, str): | |||||
return TensorMethodPattern(target) | |||||
else: | |||||
raise ValueError("not support") | |||||
def is_const(): | |||||
return ConstantPattern().check_users(False) | |||||
def any_node(): | |||||
return AnyPattern() | |||||
def is_var(): | |||||
return VarPattern() |
@@ -0,0 +1,38 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import copy | |||||
from typing import Any, Dict, List | |||||
from ..expr import Expr, is_constant, is_getattr | |||||
from ..node import Node, TensorNode | |||||
def register_obj(objs: List[Any], _dict: Dict): | |||||
if not isinstance(objs, List): | |||||
objs = [objs] | |||||
def _register(any_obj: Any): | |||||
for obj in objs: | |||||
_dict[obj] = any_obj | |||||
return any_obj | |||||
return _register | |||||
def get_const_value(expr: Expr, fall_back: Any = None): | |||||
value = fall_back | |||||
if isinstance(expr, Node): | |||||
expr = expr.expr | |||||
if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode): | |||||
module = expr.inputs[0].owner | |||||
assert module is not None | |||||
value = copy.deepcopy(expr.interpret(module)[0]) | |||||
elif is_constant(expr): | |||||
value = copy.deepcopy(expr.interpret()[0]) | |||||
return value |