You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from abc import abstractmethod
  9. from typing import Any, Callable, Dict, List
  10. from ...core._imperative_rt import OpDef
  11. from ...logger import get_logger
  12. from ...module import Module
  13. from ..expr import Expr
  14. from ..node import Node
  15. logger = get_logger(__name__)
  16. class ExprPattern:
  17. def __init__(self):
  18. self._check_users = True
  19. self._users = []
  20. def __call__(self, *args):
  21. args = list(args)
  22. if len(args) == 1 and args[0] is None:
  23. args = None
  24. return CallPattern(self, *args)
  25. def __add__(self, other):
  26. return is_op("__add__")(self, other)
  27. def __iadd__(self, other):
  28. return is_op("__iadd__")(self, other)
  29. def __radd__(self, other):
  30. return is_op("__radd__")(self, other)
  31. def __sub__(self, other):
  32. return is_op("__sub__")(self, other)
  33. def __isub__(self, other):
  34. return is_op("__isub__")(self, other)
  35. def __rsub__(self, other):
  36. return is_op("__rsub__")(self, other)
  37. def __mul__(self, other):
  38. return is_op("__mul__")(self, other)
  39. def __imul__(self, other):
  40. return is_op("__imul__")(self, other)
  41. def __rmul__(self, other):
  42. return is_op("__rmul__")(self, other)
  43. def __truediv__(self, other):
  44. return is_op("__truediv__")(self, other)
  45. def __itruediv__(self, other):
  46. return is_op("__itruediv__")(self, other)
  47. def __rtruediv__(self, other):
  48. return is_op("__rtruediv__")(self, other)
  49. def __or__(self, other):
  50. assert isinstance(other, ExprPattern)
  51. return OrPattern(self, other)
  52. def get_output(self, index):
  53. raise NotImplementedError
  54. def check_users(self, check: bool = True):
  55. self._check_users = check
  56. return self
  57. def _add_users(self, pattern: "ExprPattern"):
  58. self._users.append(pattern)
  59. def _clear_users(self,):
  60. self._users.clear()
  61. def __getitem__(self, index):
  62. return is_op("__getitem__")(self, index)
  63. def has_attr(self, **attrs):
  64. logger.warning("has_param only support ModulePattern")
  65. return self
  66. def has_param(self, **params):
  67. logger.warning("has_param only support FunctionPattern")
  68. return self
  69. @abstractmethod
  70. def __repr__(self) -> str:
  71. raise NotImplementedError
  72. class CallPattern(ExprPattern):
  73. def __init__(self, op: ExprPattern, *args: List[ExprPattern]):
  74. super().__init__()
  75. self.op = op
  76. self.args = list(filter(lambda x: isinstance(x, ExprPattern), args))
  77. self._match_all_args = True
  78. def __repr__(self) -> str:
  79. return "{}({})".format(self.op, ",".join(str(x) for x in self.args))
  80. def not_all_args(self):
  81. self._match_all_args = False
  82. def check_users(self, check: bool = True):
  83. self._check_users = check
  84. self.op.check_users(check)
  85. return self
  86. def _add_users(self, pattern: "ExprPattern"):
  87. self._users.append(pattern)
  88. self.op._add_users(pattern)
  89. def _clear_users(self):
  90. self._users.clear()
  91. self.op._clear_users()
  92. class OrPattern(ExprPattern):
  93. def __init__(self, left: ExprPattern, right: ExprPattern):
  94. super().__init__()
  95. self.left = left
  96. self.right = right
  97. def __repr__(self) -> str:
  98. return "({}|{})".format(self.left, self.right)
  99. def check_users(self, check: bool = True):
  100. self._check_users = check
  101. self.left.check_users(check)
  102. self.right.check_users(check)
  103. return self
  104. def _clear_users(self):
  105. self._users.clear()
  106. self.left._clear_users()
  107. self.right._clear_users()
  108. class GetOutputPaterrn(ExprPattern):
  109. def __init__(self, op, index):
  110. super().__init__()
  111. self.op = op
  112. self.index = index
  113. def __repr__(self) -> str:
  114. return "{}[{}]".format(self.op, self.index)
  115. class ModulePattern(ExprPattern):
  116. def __init__(self, module_cls: Module) -> None:
  117. super().__init__()
  118. self.attrs = {}
  119. self.target = module_cls
  120. def has_attr(self, **attrs):
  121. self.attrs.update(attrs)
  122. return self
  123. def __repr__(self) -> str:
  124. return "{}".format(self.target.__name__)
  125. class FunctionPattern(ExprPattern):
  126. def __init__(self, func: Callable):
  127. super().__init__()
  128. self.params = {}
  129. self.target = func
  130. def has_params(self, **params):
  131. self.params.update(params)
  132. return self
  133. def __repr__(self) -> str:
  134. return "{}".format(self.target.__name__)
  135. class TensorMethodPattern(ExprPattern):
  136. def __init__(self, method: str):
  137. super().__init__()
  138. self.target = method
  139. def __repr__(self) -> str:
  140. return self.target
  141. class ApplyDefPattern(ExprPattern):
  142. def __init__(self, opdef: OpDef):
  143. super().__init__()
  144. self.target = opdef
  145. def __repr__(self) -> str:
  146. return "{}".format(self.target.__name__)
  147. class VarPattern(ExprPattern):
  148. def __init__(self):
  149. super().__init__()
  150. def __repr__(self) -> str:
  151. return "var"
  152. class ConstantPattern(ExprPattern):
  153. def __init__(self):
  154. super().__init__()
  155. def __repr__(self) -> str:
  156. return "const"
  157. class AnyPattern(ExprPattern):
  158. def __init__(self):
  159. super().__init__()
  160. def __repr__(self) -> str:
  161. return "any"
  162. def is_op(target):
  163. if isinstance(target, type):
  164. if issubclass(target, Module):
  165. return ModulePattern(target)
  166. if issubclass(target, OpDef):
  167. return ApplyDefPattern(target)
  168. elif callable(target):
  169. return FunctionPattern(target)
  170. elif isinstance(target, str):
  171. return TensorMethodPattern(target)
  172. else:
  173. raise ValueError("not support")
  174. def is_const():
  175. return ConstantPattern().check_users(False)
  176. def any_node():
  177. return AnyPattern()
  178. def is_var():
  179. return VarPattern()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台