You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass_base.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import copy
  2. from abc import abstractmethod
  3. from collections import OrderedDict, namedtuple
  4. from functools import partial
  5. from re import T
  6. from typing import Any, Callable, Dict, Iterable, List, Union
  7. from ...logger import get_logger
  8. from ..expr import Expr
  9. from ..traced_module import InternalGraph, TracedModule
  10. from .utils import register_obj
# Module-level logger; BasePass.apply_optimization uses it to warn when a pass
# is skipped because it (or one of its required passes) is disabled.
# NOTE(review): `from re import T` in the imports appears unused in this file,
# and `re.T` was removed in Python 3.13 -- consider dropping that import.
logger = get_logger(__name__)
  12. class PassContext:
  13. def __init__(
  14. self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
  15. ):
  16. self._disabled_pass = set()
  17. self._config = pass_config
  18. self._handle = None
  19. if disabled_pass:
  20. self.add_diabled_pass(disabled_pass)
  21. def add_diabled_pass(self, passes: Iterable[str]):
  22. if isinstance(passes, str):
  23. passes = [passes]
  24. for pas in passes:
  25. self._disabled_pass.add(pas)
  26. def pass_enabled(self, pas: Union["BasePass", str]):
  27. pass_name = pas.name if isinstance(pas, BasePass) else pas
  28. return pass_name not in self._disabled_pass
  29. _default_context = PassContext()
  30. def get_default_pass_context():
  31. return _default_context
# Registry of passes, looked up by pass name in `get_registered_pass`.
# OrderedDict keeps insertion (registration) order.
_pass_dict = OrderedDict()
# `register_obj` (from .utils) with this module's registry bound as `_dict`;
# presumably it inserts the decorated/passed object into `_pass_dict` -- how the
# key is derived is defined by `register_obj` (TODO confirm there).
register_pass = partial(register_obj, _dict=_pass_dict)
  34. def get_registered_pass(pass_name: str):
  35. pas = _pass_dict.get(pass_name, None)
  36. assert (
  37. pas is not None
  38. ), "{} is not found, please call `register_pass` to register it".format(pass_name)
  39. return pas
  40. class BasePass:
  41. run_once = True # bool
  42. required_pass = [] # Iterable[str]
  43. name = "" # str
  44. def __init__(self):
  45. super().__init__()
  46. def __call__(
  47. self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
  48. ) -> TracedModule:
  49. assert isinstance(pass_ctx, PassContext)
  50. return self.apply_optimization(mod, pass_ctx)
  51. def apply_optimization(
  52. self, mod: TracedModule, pass_ctx: PassContext
  53. ) -> TracedModule:
  54. new_mod = mod
  55. for pass_name in self.required_pass + [self.name]:
  56. if not pass_ctx.pass_enabled(pass_name):
  57. logger.warning(
  58. "Since {} is disabled, {} will skipped".format(pass_name, self.name)
  59. )
  60. return mod
  61. for pass_name in self.required_pass:
  62. pass_func = get_registered_pass(pass_name)()
  63. new_mod = pass_func(new_mod, pass_ctx)
  64. iter_num = 1
  65. graph_changed = self.visit_graph(new_mod.graph)
  66. while not self.run_once and graph_changed:
  67. graph_changed = self.visit_graph(new_mod.graph)
  68. iter_num += 1
  69. if iter_num == 100:
  70. break
  71. assert iter_num < 100, "{} was run 100 times, plase check for pass conflict."
  72. return new_mod
  73. @abstractmethod
  74. def visit_graph(self, graph: InternalGraph):
  75. raise NotImplementedError
  76. def before_visit_graph(self, graph: InternalGraph):
  77. pass
  78. def run_transform(self, expr: Expr) -> Expr:
  79. return expr
  80. def __repr__(self) -> str:
  81. return self.name
  82. class ForwardPass(BasePass):
  83. def visit_graph(self, graph: InternalGraph):
  84. class Item:
  85. def __init__(self, expr: Expr, child_expanded: bool = False):
  86. self.expr = expr
  87. self.child_expanded = child_expanded
  88. self.before_visit_graph(graph)
  89. graph_changed = False
  90. queue = [Item(n.expr) for n in graph.outputs]
  91. visited_expr, visited_graph = set(), set()
  92. while queue:
  93. item = queue[-1]
  94. if item.expr in visited_expr:
  95. queue.pop()
  96. elif item.child_expanded:
  97. if item.expr not in graph._exprs:
  98. queue.pop()
  99. continue
  100. new_expr = self.run_transform(item.expr)
  101. if new_expr is not item.expr:
  102. graph_changed = True
  103. assert new_expr not in visited_expr
  104. queue.append(Item(new_expr))
  105. continue
  106. if (
  107. hasattr(item.expr, "graph")
  108. and item.expr.graph is not None
  109. and item.expr.graph not in visited_graph
  110. ):
  111. graph_changed |= self.visit_graph(item.expr.graph)
  112. visited_graph.add(item.expr.graph)
  113. visited_expr.add(item.expr)
  114. else:
  115. item.child_expanded = True
  116. for i in item.expr.inputs:
  117. expr = i.expr
  118. if expr not in queue and expr not in visited_expr:
  119. queue.append(Item(expr))
  120. return graph_changed
  121. class BackwardPass(BasePass):
  122. def visit_graph(self, graph: InternalGraph):
  123. self.before_visit_graph(graph)
  124. graph_changed = False
  125. queue = [n.expr for n in graph.outputs]
  126. visited_expr, visited_graph = set(), set()
  127. while queue:
  128. expr = queue.pop()
  129. if expr not in graph._exprs:
  130. continue
  131. new_expr = self.run_transform(expr)
  132. if new_expr is not expr:
  133. graph_changed = True
  134. queue.append(new_expr)
  135. continue
  136. else:
  137. visited_expr.add(expr)
  138. if (
  139. hasattr(expr, "graph")
  140. and expr.graph is not None
  141. and expr.graph not in visited_graph
  142. ):
  143. graph_changed |= self.visit_graph(expr.graph)
  144. visited_graph.add(expr.graph)
  145. for i in expr.inputs:
  146. expr = i.expr
  147. if expr not in queue and expr not in visited_expr:
  148. queue.append(expr)
  149. return graph_changed