You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

const_pass.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from ... import functional as F
  9. from ... import module as M
  10. from ...core.ops.builtin import GetVarShape
  11. from ...logger import get_logger
  12. from ...tensor import Tensor
  13. from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
  14. from ..node import Node, NodeMixin, TensorNode
  15. from .matcher import PatternMatcher
  16. from .pass_base import BackwardPass, ForwardPass, register_pass
  17. from .pattern import is_op
  18. from .utils import get_const_value
  19. logger = get_logger(__name__)
  20. def _as_const_node(x):
  21. node = Constant.make(x)
  22. NodeMixin.wrap(x, node)
  23. return node
  24. @register_pass("AttrToConstant")
  25. class AttrToConstant(BackwardPass):
  26. r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
  27. name = "AttrToConstant"
  28. run_once = True
  29. def run_transform(self, expr: Expr):
  30. if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)):
  31. return expr
  32. graph = expr.top_graph
  33. value = get_const_value(expr)
  34. orig_node = expr.outputs[0]
  35. name = orig_node.name
  36. with graph.insert_exprs(expr):
  37. const_node = _as_const_node(value)
  38. graph.replace_node({orig_node: const_node})
  39. graph.compile()
  40. const_node.name = name
  41. return const_node.expr
  42. @register_pass("FixInputShape")
  43. class FixInputShape(BackwardPass):
  44. name = "FixInputShape"
  45. run_once = True
  46. def run_transform(self, expr: Expr):
  47. if not is_apply_def(expr, GetVarShape):
  48. return expr
  49. shape = Tensor(expr.inputs[0].shape, dtype="int32")
  50. graph = expr.top_graph
  51. with graph.insert_exprs(expr):
  52. const_shape = _as_const_node(shape)
  53. graph.replace_node({expr.outputs[0]: const_shape})
  54. graph.compile()
  55. const_shape.name = expr.outputs[0].name
  56. return const_shape.expr
  57. @register_pass("FlodConstant")
  58. class FlodConstant(ForwardPass):
  59. r"""Constant folding."""
  60. name = "FlodConstant"
  61. required_pass = ["AttrToConstant"]
  62. run_once = False
  63. def run_transform(self, expr: Expr):
  64. if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs):
  65. return expr
  66. const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
  67. graph = expr.top_graph
  68. with graph.insert_exprs(expr):
  69. const_node = _as_const_node(const_var)
  70. graph.replace_node({expr.outputs[0]: const_node})
  71. graph.compile()
  72. const_node.name = expr.outputs[0].name
  73. return const_node.expr
@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
    r"""Transform add/sub or mul/div expr to add-only or mul-only chains.

    For example, the following code

    .. code-block::

        b = 1 - a
        c = 2 * b
        d = 1 / c

    will be changed to

    .. code-block::

        a1 = F.neg(a)
        b = a1 + 1
        c = b * 2
        d = F.pow(c, -1)
    """
    name = "NormElemWise"
    required_pass = ["FlodConstant"]
    run_once = False

    def __init__(self,):
        super().__init__()
        # Pattern matching any of: functional add/sub/mul/div, plus the
        # dunder-method variants (in-place and reflected included).
        self.pattern = is_op(F.add)
        for op in [F.sub, F.mul, F.div]:
            self.pattern |= is_op(op)
        for op in ["__add__", "__iadd__", "__radd__"]:
            self.pattern |= is_op(op)
        for op in ["__sub__", "__isub__", "__rsub__"]:
            self.pattern |= is_op(op)
        for op in ["__mul__", "__imul__", "__rmul__"]:
            self.pattern |= is_op(op)
        for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
            self.pattern |= is_op(op)

    def run_transform(self, expr: Expr):
        matcher = PatternMatcher()
        if not matcher.match(self.pattern, expr):
            return expr
        pattern = matcher.matched_patterns[0]
        target = pattern.target
        # cofee == -1 marks that the tensor operand is on the "wrong" side of a
        # non-commutative op (sub/div), so it must be negated / inverted below.
        cofee, left_node, right_node = 1, None, None
        if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
            # One tensor input: the other operand was a literal constant.
            left_node = expr.inputs[0]
            right_node = expr.const_val[0][-1]
            if target in ["__rsub__", "__rtruediv__"]:
                cofee = -1
            # For the functional forms, check whether the tensor was the
            # "x" argument; if not, it sits on the right-hand side.
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                cofee = -1
        elif len(expr.inputs) == 2 and (
            target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
        ):
            left_node, right_node = expr.inputs
            # Reflected dunder ops receive their operands swapped.
            if target in ["__rsub__", "__rtruediv__"]:
                left_node, right_node = right_node, left_node
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                left_node, right_node = right_node, left_node
            # Keep the non-constant operand on the left; the swap flips the
            # operand order, so the coefficient must flip too.
            if is_constant(left_node.expr):
                left_node, right_node = right_node, left_node
                cofee = -1
        if left_node is None:
            # Expr shape did not match any normalizable form; leave it alone.
            return expr
        if isinstance(right_node, TensorNode):
            # Materialize a constant right operand as a plain value when possible.
            right_node = get_const_value(right_node.expr, right_node)
        graph = expr.top_graph
        with graph.insert_exprs():
            if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
                out_node = left_node * right_node
            elif target in ["__add__", "__iadd__", "__radd__", F.add]:
                out_node = left_node + right_node
            elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
                # a - b  ->  a + (-b); (reflected) b - a -> (-a) + b
                if cofee == -1:
                    left_node = F.neg(left_node)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.neg(right_node)
                    else:
                        right_node = -1 * right_node
                out_node = left_node + right_node
            elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
                # a / b  ->  a * b^-1; (reflected) b / a -> a^-1 * b
                if cofee == -1:
                    left_node = F.pow(left_node, -1)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.pow(right_node, -1)
                    else:
                        right_node = 1 / right_node
                out_node = left_node * right_node
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr