|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- from abc import abstractmethod
- from typing import Any, Callable, Dict, List
-
- from ...core._imperative_rt import OpDef
- from ...logger import get_logger
- from ...module import Module
- from ..expr import Expr
- from ..node import Node
-
- logger = get_logger(__name__)
-
-
class ExprPattern:
    """Base class of every pattern object defined in this module.

    Patterns are combined with ordinary Python syntax:

    * calling a pattern (``pat(a, b)``) wraps it in a ``CallPattern``;
    * arithmetic operators and indexing build a call of the matching
      tensor-method pattern (``is_op("__add__")(self, other)`` and so on);
    * ``a | b`` builds an ``OrPattern``.

    Each instance carries a ``_check_users`` flag and a ``_users`` list.
    This class only stores them; how they are consumed is up to the code
    that walks the pattern (not part of this class).
    """

    def __init__(self):
        self._check_users = True
        self._users = []

    def __call__(self, *args):
        """Return a ``CallPattern`` that applies this pattern to ``args``.

        A single ``None`` argument is treated as "no argument patterns".
        The previous code replaced ``args`` with ``None`` and then unpacked
        it with ``*args``, which always raised ``TypeError``.
        """
        if len(args) == 1 and args[0] is None:
            args = ()
        return CallPattern(self, *args)

    # -- arithmetic operators: each builds a call of the same-named method --

    def __add__(self, other):
        return is_op("__add__")(self, other)

    def __iadd__(self, other):
        return is_op("__iadd__")(self, other)

    def __radd__(self, other):
        return is_op("__radd__")(self, other)

    def __sub__(self, other):
        return is_op("__sub__")(self, other)

    def __isub__(self, other):
        return is_op("__isub__")(self, other)

    def __rsub__(self, other):
        return is_op("__rsub__")(self, other)

    def __mul__(self, other):
        return is_op("__mul__")(self, other)

    def __imul__(self, other):
        return is_op("__imul__")(self, other)

    def __rmul__(self, other):
        return is_op("__rmul__")(self, other)

    def __truediv__(self, other):
        return is_op("__truediv__")(self, other)

    def __itruediv__(self, other):
        return is_op("__itruediv__")(self, other)

    def __rtruediv__(self, other):
        return is_op("__rtruediv__")(self, other)

    def __or__(self, other):
        """Return an ``OrPattern`` of ``self`` and ``other``."""
        assert isinstance(other, ExprPattern)
        return OrPattern(self, other)

    def get_output(self, index):
        """Not implemented on the base class; always raises."""
        raise NotImplementedError

    def check_users(self, check: bool = True):
        """Set the ``_check_users`` flag and return ``self`` for chaining."""
        self._check_users = check
        return self

    def _add_users(self, pattern: "ExprPattern"):
        """Append ``pattern`` to this pattern's ``_users`` list."""
        self._users.append(pattern)

    def _clear_users(self):
        """Empty this pattern's ``_users`` list in place."""
        self._users.clear()

    def __getitem__(self, index):
        return is_op("__getitem__")(self, index)

    def has_attr(self, **attrs):
        """Fallback for patterns without attribute constraints.

        Only logs a warning and returns ``self``; ``ModulePattern``
        overrides this to record the constraints.
        """
        # The message used to say "has_param", naming the wrong method.
        logger.warning("has_attr only support ModulePattern")
        return self

    def has_param(self, **params):
        """Fallback that ignores ``params``, logs a warning, returns ``self``."""
        logger.warning("has_param only support FunctionPattern")
        return self

    def has_params(self, **params):
        """Fallback under the name ``FunctionPattern`` actually overrides.

        Without it, ``has_params`` on any other pattern type raised
        ``AttributeError`` instead of producing the intended warning.
        """
        logger.warning("has_params only support FunctionPattern")
        return self

    # NOTE: this class does not use ``ABCMeta``, so the decorator does not
    # block instantiation; it only marks ``__repr__`` as something every
    # subclass is expected to provide.
    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError
-
-
class CallPattern(ExprPattern):
    """Pattern for ``op`` applied to a list of argument patterns.

    ``op`` is the pattern being called. Positional arguments that are not
    ``ExprPattern`` instances are dropped, so only patterns end up in
    ``self.args``.
    """

    def __init__(self, op: ExprPattern, *args: List[ExprPattern]):
        super().__init__()
        self.op = op
        # Keep pattern arguments only; anything else is silently ignored.
        self.args = [arg for arg in args if isinstance(arg, ExprPattern)]
        self._match_all_args = True

    def __repr__(self) -> str:
        return "{}({})".format(self.op, ",".join(str(x) for x in self.args))

    def not_all_args(self):
        """Clear the ``_match_all_args`` flag and return ``self``.

        Returning ``self`` (instead of the previous implicit ``None``) makes
        this chainable like ``check_users``; callers that ignore the return
        value are unaffected.
        """
        self._match_all_args = False
        return self

    def check_users(self, check: bool = True):
        """Set the flag on this pattern and on ``op``; return ``self``."""
        self._check_users = check
        self.op.check_users(check)
        return self

    def _add_users(self, pattern: "ExprPattern"):
        """Record ``pattern`` as a user of this pattern and of ``op``."""
        self._users.append(pattern)
        self.op._add_users(pattern)

    def _clear_users(self):
        """Clear the user lists of this pattern and of ``op``."""
        self._users.clear()
        self.op._clear_users()
-
-
class OrPattern(ExprPattern):
    """Alternation of two patterns, built by ``left | right``."""

    def __init__(self, left: ExprPattern, right: ExprPattern):
        super().__init__()
        self.left, self.right = left, right

    def __repr__(self) -> str:
        return "({lhs}|{rhs})".format(lhs=self.left, rhs=self.right)

    def check_users(self, check: bool = True):
        """Store ``check`` here and forward it to both branches."""
        self._check_users = check
        # Left branch first, then right, as the two calls are independent
        # but the order is kept for identical side-effect sequencing.
        for branch in (self.left, self.right):
            branch.check_users(check)
        return self

    def _clear_users(self):
        """Clear the user list of this pattern and of both branches."""
        self._users.clear()
        for branch in (self.left, self.right):
            branch._clear_users()
-
-
class GetOutputPaterrn(ExprPattern):
    """Pattern holding an ``op`` pattern together with an ``index``.

    The class name is misspelled but kept as-is because it is the public
    identifier of this class.
    """

    def __init__(self, op, index):
        super().__init__()
        self.op, self.index = op, index

    def __repr__(self) -> str:
        return "{op}[{idx}]".format(op=self.op, idx=self.index)
-
-
class ModulePattern(ExprPattern):
    """Pattern whose target is a module class, with optional attributes."""

    def __init__(self, module_cls: Module) -> None:
        super().__init__()
        self.target = module_cls
        # Attribute constraints accumulated through ``has_attr``.
        self.attrs = {}

    def has_attr(self, **attrs):
        """Merge ``attrs`` into the stored constraints; return ``self``."""
        self.attrs.update(attrs)
        return self

    def __repr__(self) -> str:
        name = self.target.__name__
        return format(name)
-
-
class FunctionPattern(ExprPattern):
    """Pattern whose target is a callable, with optional parameters."""

    def __init__(self, func: Callable):
        super().__init__()
        # Parameter constraints accumulated through ``has_params``.
        self.params = {}
        self.target = func

    def has_params(self, **params):
        """Merge ``params`` into the stored constraints; return ``self``."""
        self.params.update(params)
        return self

    def has_param(self, **params):
        """Singular spelling of ``has_params``.

        The base class exposes ``has_param`` as its warning-only fallback;
        without this override, calling it on a ``FunctionPattern`` would
        log a warning and drop the constraints.
        """
        return self.has_params(**params)

    def __repr__(self) -> str:
        # Callables such as ``functools.partial`` objects have no
        # ``__name__``; fall back to their repr rather than raising.
        return "{}".format(getattr(self.target, "__name__", repr(self.target)))
-
-
class TensorMethodPattern(ExprPattern):
    """Pattern whose target is a method name given as a string."""

    def __init__(self, method: str):
        super().__init__()
        # Stored unchanged; it is also what ``__repr__`` returns.
        self.target = method

    def __repr__(self) -> str:
        method_name = self.target
        return method_name
-
-
class ApplyDefPattern(ExprPattern):
    """Pattern whose target is an ``OpDef`` object."""

    def __init__(self, opdef: OpDef):
        super().__init__()
        self.target = opdef

    def __repr__(self) -> str:
        name = self.target.__name__
        return format(name)
-
-
class VarPattern(ExprPattern):
    """Leaf pattern with no state of its own; displayed as ``var``."""

    # No ``__init__`` is needed: the inherited one performs the same setup
    # the explicit pass-through constructor did.

    def __repr__(self) -> str:
        label = "var"
        return label
-
-
class ConstantPattern(ExprPattern):
    """Leaf pattern with no state of its own; displayed as ``const``."""

    # No ``__init__`` is needed: the inherited one performs the same setup
    # the explicit pass-through constructor did.

    def __repr__(self) -> str:
        label = "const"
        return label
-
-
class AnyPattern(ExprPattern):
    """Leaf pattern with no state of its own; displayed as ``any``."""

    # No ``__init__`` is needed: the inherited one performs the same setup
    # the explicit pass-through constructor did.

    def __repr__(self) -> str:
        label = "any"
        return label
-
-
def is_op(target):
    """Build the pattern object that corresponds to ``target``.

    Dispatch, in order:

    * a class derived from ``Module``  -> ``ModulePattern``
    * a class derived from ``OpDef``   -> ``ApplyDefPattern``
    * any other callable (non-class)   -> ``FunctionPattern``
    * a string (method name)           -> ``TensorMethodPattern``

    Raises:
        ValueError: if ``target`` is none of the above. This includes
            classes that derive from neither ``Module`` nor ``OpDef``;
            they used to fall through and silently return ``None``.
    """
    if isinstance(target, type):
        if issubclass(target, Module):
            return ModulePattern(target)
        if issubclass(target, OpDef):
            return ApplyDefPattern(target)
        # An unrelated class: fail loudly here rather than hand back None
        # and let the caller crash later on an unrelated attribute access.
        raise ValueError("not support")
    if callable(target):
        return FunctionPattern(target)
    if isinstance(target, str):
        return TensorMethodPattern(target)
    raise ValueError("not support")
-
-
def is_const():
    """Return a new ``ConstantPattern`` with ``check_users(False)`` applied."""
    pattern = ConstantPattern()
    return pattern.check_users(False)
-
-
def any_node():
    """Return a new ``AnyPattern`` instance."""
    pattern = AnyPattern()
    return pattern
-
-
def is_var():
    """Return a new ``VarPattern`` instance."""
    pattern = VarPattern()
    return pattern
|