You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. from abc import abstractmethod
  2. from typing import Any, Callable, Dict, List
  3. from ...core._imperative_rt import OpDef
  4. from ...logger import get_logger
  5. from ...module import Module
  6. from ..expr import Expr
  7. from ..node import Node
# Module-level logger; ExprPattern's fallback has_attr/has_param methods
# emit their "only support ..." warnings through it.
logger = get_logger(__name__)
  9. class ExprPattern:
  10. def __init__(self):
  11. self._check_users = True
  12. self._users = []
  13. def __call__(self, *args):
  14. args = list(args)
  15. if len(args) == 1 and args[0] is None:
  16. args = None
  17. return CallPattern(self, *args)
  18. def __add__(self, other):
  19. return is_op("__add__")(self, other)
  20. def __iadd__(self, other):
  21. return is_op("__iadd__")(self, other)
  22. def __radd__(self, other):
  23. return is_op("__radd__")(self, other)
  24. def __sub__(self, other):
  25. return is_op("__sub__")(self, other)
  26. def __isub__(self, other):
  27. return is_op("__isub__")(self, other)
  28. def __rsub__(self, other):
  29. return is_op("__rsub__")(self, other)
  30. def __mul__(self, other):
  31. return is_op("__mul__")(self, other)
  32. def __imul__(self, other):
  33. return is_op("__imul__")(self, other)
  34. def __rmul__(self, other):
  35. return is_op("__rmul__")(self, other)
  36. def __truediv__(self, other):
  37. return is_op("__truediv__")(self, other)
  38. def __itruediv__(self, other):
  39. return is_op("__itruediv__")(self, other)
  40. def __rtruediv__(self, other):
  41. return is_op("__rtruediv__")(self, other)
  42. def __or__(self, other):
  43. assert isinstance(other, ExprPattern)
  44. return OrPattern(self, other)
  45. def get_output(self, index):
  46. raise NotImplementedError
  47. def check_users(self, check: bool = True):
  48. self._check_users = check
  49. return self
  50. def _add_users(self, pattern: "ExprPattern"):
  51. self._users.append(pattern)
  52. def _clear_users(self,):
  53. self._users.clear()
  54. def __getitem__(self, index):
  55. return is_op("__getitem__")(self, index)
  56. def has_attr(self, **attrs):
  57. logger.warning("has_param only support ModulePattern")
  58. return self
  59. def has_param(self, **params):
  60. logger.warning("has_param only support FunctionPattern")
  61. return self
  62. @abstractmethod
  63. def __repr__(self) -> str:
  64. raise NotImplementedError
  65. class CallPattern(ExprPattern):
  66. def __init__(self, op: ExprPattern, *args: List[ExprPattern]):
  67. super().__init__()
  68. self.op = op
  69. self.args = list(filter(lambda x: isinstance(x, ExprPattern), args))
  70. self._match_all_args = True
  71. def __repr__(self) -> str:
  72. return "{}({})".format(self.op, ",".join(str(x) for x in self.args))
  73. def not_all_args(self):
  74. self._match_all_args = False
  75. def check_users(self, check: bool = True):
  76. self._check_users = check
  77. self.op.check_users(check)
  78. return self
  79. def _add_users(self, pattern: "ExprPattern"):
  80. self._users.append(pattern)
  81. self.op._add_users(pattern)
  82. def _clear_users(self):
  83. self._users.clear()
  84. self.op._clear_users()
  85. class OrPattern(ExprPattern):
  86. def __init__(self, left: ExprPattern, right: ExprPattern):
  87. super().__init__()
  88. self.left = left
  89. self.right = right
  90. def __repr__(self) -> str:
  91. return "({}|{})".format(self.left, self.right)
  92. def check_users(self, check: bool = True):
  93. self._check_users = check
  94. self.left.check_users(check)
  95. self.right.check_users(check)
  96. return self
  97. def _clear_users(self):
  98. self._users.clear()
  99. self.left._clear_users()
  100. self.right._clear_users()
  101. class GetOutputPaterrn(ExprPattern):
  102. def __init__(self, op, index):
  103. super().__init__()
  104. self.op = op
  105. self.index = index
  106. def __repr__(self) -> str:
  107. return "{}[{}]".format(self.op, self.index)
  108. class ModulePattern(ExprPattern):
  109. def __init__(self, module_cls: Module) -> None:
  110. super().__init__()
  111. self.attrs = {}
  112. self.target = module_cls
  113. def has_attr(self, **attrs):
  114. self.attrs.update(attrs)
  115. return self
  116. def __repr__(self) -> str:
  117. return "{}".format(self.target.__name__)
  118. class FunctionPattern(ExprPattern):
  119. def __init__(self, func: Callable):
  120. super().__init__()
  121. self.params = {}
  122. self.target = func
  123. def has_params(self, **params):
  124. self.params.update(params)
  125. return self
  126. def __repr__(self) -> str:
  127. return "{}".format(self.target.__name__)
  128. class TensorMethodPattern(ExprPattern):
  129. def __init__(self, method: str):
  130. super().__init__()
  131. self.target = method
  132. def __repr__(self) -> str:
  133. return self.target
  134. class ApplyDefPattern(ExprPattern):
  135. def __init__(self, opdef: OpDef):
  136. super().__init__()
  137. self.target = opdef
  138. def __repr__(self) -> str:
  139. return "{}".format(self.target.__name__)
  140. class VarPattern(ExprPattern):
  141. def __init__(self):
  142. super().__init__()
  143. def __repr__(self) -> str:
  144. return "var"
  145. class ConstantPattern(ExprPattern):
  146. def __init__(self):
  147. super().__init__()
  148. def __repr__(self) -> str:
  149. return "const"
  150. class AnyPattern(ExprPattern):
  151. def __init__(self):
  152. super().__init__()
  153. def __repr__(self) -> str:
  154. return "any"
  155. def is_op(target):
  156. if isinstance(target, type):
  157. if issubclass(target, Module):
  158. return ModulePattern(target)
  159. if issubclass(target, OpDef):
  160. return ApplyDefPattern(target)
  161. elif callable(target):
  162. return FunctionPattern(target)
  163. elif isinstance(target, str):
  164. return TensorMethodPattern(target)
  165. else:
  166. raise ValueError("not support")
  167. def is_const():
  168. return ConstantPattern().check_users(False)
  169. def any_node():
  170. return AnyPattern()
  171. def is_var():
  172. return VarPattern()