You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

const_pass.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from ... import functional as F
  9. from ... import module as M
  10. from ...core.ops.builtin import GetVarShape
  11. from ...logger import get_logger
  12. from ...tensor import Tensor
  13. from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
  14. from ..node import Node, TensorNode
  15. from .matcher import PatternMatcher
  16. from .pass_base import BackwardPass, ForwardPass, register_pass
  17. from .pattern import is_op
  18. from .utils import get_const_value
  19. logger = get_logger(__name__)
  20. @register_pass("AttrToConstant")
  21. class AttrToConstant(BackwardPass):
  22. r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
  23. name = "AttrToConstant"
  24. run_once = True
  25. def run_transform(self, expr: Expr):
  26. if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)):
  27. return expr
  28. graph = expr.top_graph
  29. value = get_const_value(expr)
  30. orig_node = expr.outputs[0]
  31. name = orig_node.name
  32. with graph.insert_exprs(expr):
  33. const_node = Constant.make(value, name=name)
  34. graph.replace_node({orig_node: const_node})
  35. graph.compile()
  36. name = orig_node.name
  37. return const_node.expr
  38. @register_pass("FixInputShape")
  39. class FixInputShape(BackwardPass):
  40. name = "FixInputShape"
  41. run_once = True
  42. def run_transform(self, expr: Expr):
  43. if not is_apply_def(expr, GetVarShape):
  44. return expr
  45. shape = Tensor(expr.inputs[0].shape, dtype="int32")
  46. graph = expr.top_graph
  47. with graph.insert_exprs(expr):
  48. const_shape = Constant.make(shape)
  49. graph.replace_node({expr.outputs[0]: const_shape})
  50. graph.compile()
  51. const_shape.name = expr.outputs[0].name
  52. return const_shape.expr
  53. @register_pass("FlodConstant")
  54. class FlodConstant(ForwardPass):
  55. r"""Constant folding."""
  56. name = "FlodConstant"
  57. required_pass = ["AttrToConstant"]
  58. run_once = False
  59. def run_transform(self, expr: Expr):
  60. if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs):
  61. return expr
  62. const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
  63. graph = expr.top_graph
  64. with graph.insert_exprs(expr):
  65. const_node = Constant.make(const_var)
  66. graph.replace_node({expr.outputs[0]: const_node})
  67. graph.compile()
  68. const_node.name = expr.outputs[0].name
  69. return const_node.expr
@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
    r"""Transform add/sub or mul/div expr to add-only or mul-only chains.

    For example, the following code

    .. code-block::

        b = 1 - a
        c = 2 * b
        d = 1 / c

    will be changed to

    .. code-block::

        a1 = F.neg(a)
        b = a1 + 1
        c = b * 2
        d = F.pow(c, -1)
    """

    name = "NormElemWise"
    required_pass = ["FlodConstant"]
    run_once = False

    def __init__(self,):
        super().__init__()
        # Match any add/sub/mul/div expr, whether spelled as a functional op
        # (F.add, ...) or as a tensor magic method (__add__, __iadd__, ...).
        self.pattern = is_op(F.add)
        for op in [F.sub, F.mul, F.div]:
            self.pattern |= is_op(op)
        for op in ["__add__", "__iadd__", "__radd__"]:
            self.pattern |= is_op(op)
        for op in ["__sub__", "__isub__", "__rsub__"]:
            self.pattern |= is_op(op)
        for op in ["__mul__", "__imul__", "__rmul__"]:
            self.pattern |= is_op(op)
        for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
            self.pattern |= is_op(op)

    def run_transform(self, expr: Expr):
        matcher = PatternMatcher()
        if not matcher.match(self.pattern, expr):
            return expr
        pattern = matcher.matched_patterns[0]
        target = pattern.target
        # cofee == -1 marks that the tensor operand (left_node) must itself be
        # negated (for sub) or inverted (for div) before rebuilding the expr,
        # because it was originally the right-hand operand.
        cofee, left_node, right_node = 1, None, None
        if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
            # Single tensor input: the other operand was a Python constant,
            # recorded in expr.const_val.
            left_node = expr.inputs[0]
            right_node = expr.const_val[0][-1]
            if target in ["__rsub__", "__rtruediv__"]:
                # Reflected op: the tensor is actually the right operand.
                cofee = -1
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                # Functional call with the tensor passed as "y", not "x".
                cofee = -1
        elif len(expr.inputs) == 2 and (
            target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
        ):
            # Two tensor inputs; normalize so left_node is the non-constant
            # (or canonical left) operand.
            left_node, right_node = expr.inputs
            if target in ["__rsub__", "__rtruediv__"]:
                left_node, right_node = right_node, left_node
            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
                left_node, right_node = right_node, left_node
            if is_constant(left_node.expr):
                # Constant ended up on the left: swap, and remember that the
                # operand order was reversed.
                left_node, right_node = right_node, left_node
                cofee = -1
        if left_node is None:
            # Pattern matched but operands could not be normalized; leave
            # the expr untouched.
            return expr
        if isinstance(right_node, TensorNode):
            # Collapse a constant tensor node to its concrete value when
            # possible; otherwise keep the node itself.
            right_node = get_const_value(right_node.expr, right_node)
        graph = expr.top_graph
        with graph.insert_exprs():
            if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
                out_node = left_node * right_node
            elif target in ["__add__", "__iadd__", "__radd__", F.add]:
                out_node = left_node + right_node
            elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
                # a - b  ->  a + (-b); which side is negated depends on cofee.
                if cofee == -1:
                    left_node = F.neg(left_node)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.neg(right_node)
                    else:
                        right_node = -1 * right_node
                out_node = left_node + right_node
            elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
                # a / b  ->  a * b**-1; which side is inverted depends on cofee.
                if cofee == -1:
                    left_node = F.pow(left_node, -1)
                else:
                    if isinstance(right_node, TensorNode):
                        right_node = F.pow(right_node, -1)
                    else:
                        right_node = 1 / right_node
                out_node = left_node * right_node
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台