You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass_base.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import copy
  9. from abc import abstractmethod
  10. from collections import OrderedDict, namedtuple
  11. from functools import partial
  12. from re import T
  13. from typing import Any, Callable, Dict, Iterable, List, Union
  14. from ...logger import get_logger
  15. from ..expr import Expr
  16. from ..traced_module import InternalGraph, TracedModule
  17. from .utils import register_obj
  18. logger = get_logger(__name__)
  class PassContext:
      """Settings consulted when passes are applied.

      Holds the set of disabled pass names and an opaque configuration mapping.

      Args:
          disabled_pass: name(s) of passes to disable. A single ``str`` is
              treated as one pass name, not iterated character by character.
          pass_config: stored unchanged on ``_config``; this class does not
              interpret it.
      """

      def __init__(
          self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
      ):
          self._disabled_pass = set()
          self._config = pass_config
          self._handle = None  # not read or written anywhere else in this class
          if disabled_pass:
              # Dispatch through the historical name so subclasses overriding
              # it keep working.
              self.add_diabled_pass(disabled_pass)

      def add_diabled_pass(self, passes: Iterable[str]):
          """Disable the given pass name(s).

          The method name is a historical misspelling kept for backward
          compatibility; new code should call ``add_disabled_pass``.

          Args:
              passes: one pass name, or an iterable of pass names.
          """
          if isinstance(passes, str):
              passes = [passes]
          self._disabled_pass.update(passes)

      def add_disabled_pass(self, passes: Iterable[str]):
          """Disable the given pass name(s); correctly spelled entry point.

          Args:
              passes: one pass name, or an iterable of pass names.
          """
          return self.add_diabled_pass(passes)

      def pass_enabled(self, pas: Union["BasePass", str]):
          """Return True unless the pass has been disabled in this context.

          Args:
              pas: a ``BasePass`` instance (its ``name`` is used) or a pass name.
          """
          pass_name = pas.name if isinstance(pas, BasePass) else pas
          return pass_name not in self._disabled_pass
  # Single module-level context; created once when this module is imported.
  _default_context = PassContext()


  def get_default_pass_context():
      """Return the shared module-level ``PassContext`` instance."""
      return _default_context
  # Registry consulted by ``get_registered_pass``; looked up by pass name.
  _pass_dict = OrderedDict()
  # ``register_obj`` with its ``_dict`` argument pre-bound to the registry above.
  register_pass = partial(register_obj, _dict=_pass_dict)
  def get_registered_pass(pass_name: str):
      """Look up ``pass_name`` in the pass registry and return the stored object.

      Raises:
          AssertionError: if nothing is registered under ``pass_name``.
      """
      registered = _pass_dict.get(pass_name)
      assert registered is not None, (
          "{} is not found, please call `register_pass` to register it".format(
              pass_name
          )
      )
      return registered
  class BasePass:
      """Base class for passes that rewrite the graph of a ``TracedModule``.

      Subclasses provide ``visit_graph`` (the traversal) and usually override
      ``run_transform`` (the per-expression rewrite). Calling an instance runs
      the required passes first and then this pass.
      """

      # bool: when False, ``visit_graph`` is repeated until it reports no change.
      run_once = True
      # Iterable[str]: names of registered passes that are run before this one.
      required_pass = []
      # str: the name checked against the context's disabled passes.
      name = ""

      def __init__(self):
          super().__init__()

      def __call__(
          self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
      ) -> TracedModule:
          """Apply this pass to ``mod`` under ``pass_ctx`` and return the result.

          Raises:
              AssertionError: if ``pass_ctx`` is not a ``PassContext``.
          """
          assert isinstance(pass_ctx, PassContext)
          return self.apply_optimization(mod, pass_ctx)

      def apply_optimization(
          self, mod: TracedModule, pass_ctx: PassContext
      ) -> TracedModule:
          """Run the required passes, then this pass, on ``mod``.

          If this pass or any of its required passes is disabled in
          ``pass_ctx``, a warning is logged and ``mod`` is returned untouched.

          Args:
              mod: the module whose graph is visited.
              pass_ctx: context deciding which passes are enabled.

          Returns:
              The module produced by the required passes, with its graph
              visited by this pass.

          Raises:
              AssertionError: if the pass is still being re-run when the
                  iteration limit is reached.
          """
          max_iterations = 100
          # Materialize once: ``required_pass`` is documented as any iterable,
          # and ``tuple + list`` would raise TypeError.
          required = list(self.required_pass)

          for pass_name in required + [self.name]:
              if not pass_ctx.pass_enabled(pass_name):
                  logger.warning(
                      "Since {} is disabled, {} will skipped".format(pass_name, self.name)
                  )
                  return mod

          new_mod = mod
          for pass_name in required:
              pass_func = get_registered_pass(pass_name)()
              new_mod = pass_func(new_mod, pass_ctx)

          iter_num = 1
          graph_changed = self.visit_graph(new_mod.graph)
          while not self.run_once and graph_changed:
              graph_changed = self.visit_graph(new_mod.graph)
              iter_num += 1
              if iter_num == max_iterations:
                  break
          # The message is formatted here; previously the "{}" placeholder was
          # emitted literally and the offending pass was never named.
          assert (
              iter_num < max_iterations
          ), "{} was run {} times, please check for pass conflict.".format(
              self.name, max_iterations
          )
          return new_mod

      # This class has no ABC metaclass visible here, so the decorator alone
      # does not block instantiation; the NotImplementedError below is what
      # fires when a subclass does not override the method.
      @abstractmethod
      def visit_graph(self, graph: InternalGraph):
          """Traverse ``graph``; the traversals in this file return whether it changed."""
          raise NotImplementedError

      def before_visit_graph(self, graph: InternalGraph):
          """Hook invoked before a graph is traversed; does nothing by default."""
          pass

      def run_transform(self, expr: Expr) -> Expr:
          """Rewrite one expression; returning ``expr`` itself means unchanged."""
          return expr

      def __repr__(self) -> str:
          return self.name
  class ForwardPass(BasePass):
      """Pass that transforms an expression only after its inputs' expressions.

      Traversal starts from the graph outputs and uses an explicit stack. An
      entry stays on the stack until the producers of its inputs have been
      handled, and only then is ``run_transform`` applied to it.
      """

      def visit_graph(self, graph: InternalGraph):
          """Visit ``graph`` and return True if any expression was replaced."""

          class Item:
              # Stack entry: an expr plus whether its inputs were already pushed.
              def __init__(self, expr: Expr, child_expanded: bool = False):
                  self.expr = expr
                  self.child_expanded = child_expanded

          self.before_visit_graph(graph)
          changed = False
          stack = [Item(node.expr) for node in graph.outputs]
          seen_exprs, seen_graphs = set(), set()

          while stack:
              top = stack[-1]
              current = top.expr

              if current in seen_exprs:
                  stack.pop()
                  continue

              if not top.child_expanded:
                  # First encounter: push the producers of the inputs and come
                  # back to this entry once they are done.
                  top.child_expanded = True
                  for inp in current.inputs:
                      producer = inp.expr
                      # NOTE(review): ``stack`` holds Item wrappers, so the first
                      # test compares an Expr against Item objects - confirm
                      # whether it can ever match.
                      if producer not in stack and producer not in seen_exprs:
                          stack.append(Item(producer))
                  continue

              if current not in graph._exprs:
                  # No longer part of this graph, nothing to transform.
                  stack.pop()
                  continue

              replacement = self.run_transform(current)
              if replacement is not current:
                  changed = True
                  assert replacement not in seen_exprs
                  # The replacement is traversed like any other expr.
                  stack.append(Item(replacement))
                  continue

              sub_graph = getattr(current, "graph", None)
              if sub_graph is not None and sub_graph not in seen_graphs:
                  changed |= self.visit_graph(sub_graph)
                  seen_graphs.add(sub_graph)

              # The entry itself is popped on the next iteration by the
              # ``seen_exprs`` check at the top of the loop.
              seen_exprs.add(current)

          return changed
  class BackwardPass(BasePass):
      """Pass that transforms an expression before its inputs' expressions.

      Traversal starts from the graph outputs; each popped expression is
      transformed first and the producers of its inputs are queued afterwards.
      """

      def visit_graph(self, graph: InternalGraph):
          """Visit ``graph`` and return True if any expression was replaced."""
          self.before_visit_graph(graph)
          changed = False
          pending = [node.expr for node in graph.outputs]
          done_exprs, done_graphs = set(), set()

          while pending:
              current = pending.pop()
              if current not in graph._exprs:
                  # No longer part of this graph, skip it.
                  continue

              replacement = self.run_transform(current)
              if replacement is not current:
                  changed = True
                  # Process the replacement next instead of the original.
                  pending.append(replacement)
                  continue
              done_exprs.add(current)

              sub_graph = getattr(current, "graph", None)
              if sub_graph is not None and sub_graph not in done_graphs:
                  changed |= self.visit_graph(sub_graph)
                  done_graphs.add(sub_graph)

              for inp in current.inputs:
                  producer = inp.expr
                  if producer not in pending and producer not in done_exprs:
                      pending.append(producer)

          return changed

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台