You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

const_pass.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from ... import functional as F
  2. from ... import module as M
  3. from ...core.ops.builtin import GetVarShape
  4. from ...logger import get_logger
  5. from ...tensor import Tensor
  6. from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
  7. from ..node import Node, NodeMixin, TensorNode
  8. from .matcher import PatternMatcher
  9. from .pass_base import BackwardPass, ForwardPass, register_pass
  10. from .pattern import is_op
  11. from .utils import get_const_value
  12. logger = get_logger(__name__)
  13. def _as_const_node(x):
  14. node = Constant.make(x)
  15. NodeMixin.wrap(x, node)
  16. return node
@register_pass("AttrToConstant")
class AttrToConstant(BackwardPass):
    r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""

    name = "AttrToConstant"
    run_once = True

    def run_transform(self, expr: Expr):
        # Only rewrite GetAttr exprs whose first output is a TensorNode;
        # every other expr passes through unchanged.
        if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)):
            return expr
        graph = expr.top_graph
        # Materialize the attribute's value as a concrete constant.
        value = get_const_value(expr)
        orig_node = expr.outputs[0]
        name = orig_node.name
        # Insert the new Constant expr at the position of the old GetAttr,
        # then redirect all users of the old output node to the constant.
        with graph.insert_exprs(expr):
            const_node = _as_const_node(value)
        graph.replace_node({orig_node: const_node})
        graph.compile()
        # Preserve the original node name so existing references stay valid.
        const_node.name = name
        return const_node.expr
@register_pass("FixInputShape")
class FixInputShape(BackwardPass):
    r"""Replace ``GetVarShape`` exprs with the input's current shape as a
    constant int32 tensor.

    NOTE(review): this bakes the shape observed at transform time into the
    graph, so the result presumably no longer adapts to other input shapes —
    confirm with callers before applying.
    """

    name = "FixInputShape"
    run_once = True

    def run_transform(self, expr: Expr):
        if not is_apply_def(expr, GetVarShape):
            return expr
        # Statically-known shape of the input tensor, as an int32 tensor.
        shape = Tensor(expr.inputs[0].shape, dtype="int32")
        graph = expr.top_graph
        with graph.insert_exprs(expr):
            const_shape = _as_const_node(shape)
        graph.replace_node({expr.outputs[0]: const_shape})
        graph.compile()
        # Keep the replaced node's name for stable references.
        const_shape.name = expr.outputs[0].name
        return const_shape.expr
@register_pass("FlodConstant")
class FlodConstant(ForwardPass):
    r"""Constant folding: evaluate exprs whose inputs are all constants and
    replace their output with a single :class:`~.Constant`."""

    # NOTE(review): "Flod" looks like a typo for "Fold", but the registered
    # name is referenced elsewhere (e.g. NormElemWise.required_pass), so it
    # is kept for compatibility.
    name = "FlodConstant"
    required_pass = ["AttrToConstant"]
    run_once = False

    def run_transform(self, expr: Expr):
        # Fold only exprs whose inputs are all constants; exprs with no
        # inputs at all are left untouched.
        if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs):
            return expr
        # Eagerly evaluate the expr on the concrete constant input values.
        const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
        graph = expr.top_graph
        with graph.insert_exprs(expr):
            const_node = _as_const_node(const_var)
        graph.replace_node({expr.outputs[0]: const_node})
        graph.compile()
        # Preserve the folded node's name for stable references.
        const_node.name = expr.outputs[0].name
        return const_node.expr
@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
    r"""Transform add/sub or mul/div expr to add-only or mul-only chains.

    For example, the following code

    .. code-block::

        b = 1 - a
        c = 2 * b
        d = 1 / c

    will be changed to

    .. code-block::

        a1 = F.neg(a)
        b = a1 + 1
        c = b * 2
        d = F.pow(c, -1)
    """

    name = "NormElemWise"
    required_pass = ["FlodConstant"]
    run_once = False

    def __init__(self,):
        super().__init__()
        # Match every elementwise add/sub/mul/div, whether traced as the
        # functional op or as a tensor magic method (including the in-place
        # and reflected variants).
        self.pattern = is_op(F.add)
        for op in [F.sub, F.mul, F.div]:
            self.pattern |= is_op(op)
        for op in ["__add__", "__iadd__", "__radd__"]:
            self.pattern |= is_op(op)
        for op in ["__sub__", "__isub__", "__rsub__"]:
            self.pattern |= is_op(op)
        for op in ["__mul__", "__imul__", "__rmul__"]:
            self.pattern |= is_op(op)
        for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
            self.pattern |= is_op(op)

    def run_transform(self, expr: Expr):
        matcher = PatternMatcher()
        if not matcher.match(self.pattern, expr):
            return expr
        pattern = matcher.matched_patterns[0]
        target = pattern.target
        # ``cofee`` (sic — presumably "coeff") == -1 marks the reflected
        # case where the tensor operand is the subtrahend/divisor, so the
        # rewrite becomes neg(x) + c or pow(x, -1) * c.
        cofee, left_node, right_node = 1, None, None
        if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
            # One tensor input: the other operand was a scalar constant
            # recorded in ``expr.const_val``.
            left_node = expr.inputs[0]
            right_node = expr.const_val[0][-1]
            if target in ["__rsub__", "__rtruediv__"]:
                cofee = -1
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                cofee = -1
        elif len(expr.inputs) == 2 and (
            target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
        ):
            # Two tensor inputs: normalize so ``left_node`` is the
            # non-constant (or "x") operand.
            left_node, right_node = expr.inputs
            if target in ["__rsub__", "__rtruediv__"]:
                left_node, right_node = right_node, left_node
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                left_node, right_node = right_node, left_node
            if is_constant(left_node.expr):
                left_node, right_node = right_node, left_node
                cofee = -1
        if left_node is None:
            # Pattern matched but no normalization applies (e.g. plain a + b).
            return expr
        if isinstance(right_node, TensorNode):
            # Fold a constant right operand down to its concrete value
            # (falling back to the node itself if it is not constant).
            right_node = get_const_value(right_node.expr, right_node)
        graph = expr.top_graph
        with graph.insert_exprs():
            if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
                out_node = left_node * right_node
            elif target in ["__add__", "__iadd__", "__radd__", F.add]:
                out_node = left_node + right_node
            elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
                # a - b  ->  a + (-b);  reflected: c - a  ->  neg(a) + c
                if cofee == -1:
                    left_node = F.neg(left_node)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.neg(right_node)
                    else:
                        right_node = -1 * right_node
                out_node = left_node + right_node
            elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
                # a / b  ->  a * b**-1;  reflected: c / a  ->  pow(a, -1) * c
                if cofee == -1:
                    left_node = F.pow(left_node, -1)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.pow(right_node, -1)
                    else:
                        right_node = 1 / right_node
                out_node = left_node * right_node
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr