
fuse_pass.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import operator
from collections import defaultdict
from typing import Any, Callable, List

from ... import functional as F
from ... import module as M
from ...logger import get_logger
from ...tensor import Parameter, Tensor
from ...utils.bn_fusion import fold_weight_bias
from ..expr import Expr, is_call_function
from ..utils import assign_attr, get_subattr
from .matcher import PatternMatcher
from .pass_base import BackwardPass, register_pass
from .pattern import ExprPattern, any_node, is_const, is_op, is_var
from .utils import get_const_value, register_obj

logger = get_logger(__name__)


@register_pass("FuseAddMul")
class FuseAddMul(BackwardPass):
    """Fold adjacent const add or mul binary operations.

    For example, the following code

    .. code-block::

        x = x + 1
        x = 2 + x
        x = x * 4
        x = x * 0.25

    will be changed to

    .. code-block::

        x = x + 3
    """

    name = "FuseAddMul"
    required_pass = ["NormElemWise"]
    run_once = False

    def __init__(self,):
        super().__init__()
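
        # The helper below builds a two-level pattern: two chained binary ops
        # whose right-hand operand is a constant (or, via "*", any node),
        # e.g. ``(x + c0) + c1``.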
        def _make_pattern(op_0, op_1) -> ExprPattern:
            x = is_var().check_users(False)
            if op_0 not in [operator.add, operator.mul]:
                op_0 = is_op(op_0)
            if op_1 not in [operator.add, operator.mul]:
                op_1 = is_op(op_1)
            pattern = op_0(x, is_const()) | op_0(x, "*")
            pattern = op_1(pattern, is_const()) | op_1(pattern, "*")
            return pattern

        self.pattern_dict = {}

        for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],):
            self.pattern_dict[_make_pattern(op, op)] = func

        for op_0 in [F.neg, operator.mul]:
            for op_1 in [F.neg, operator.mul]:
                self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul
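
    # run_transform tries each registered pattern in turn and dispatches to its
    # fold function; if nothing matches, the expr is returned unchanged.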
    def run_transform(self, expr: Expr):
        matcher = PatternMatcher()
        for pattern, func in self.pattern_dict.items():
            res = matcher.match(pattern, expr)
            if res:
                break
        if not res:
            return expr
        return func(expr)
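
    # _fold_helper combines the two constants with ``op_c`` and re-emits a
    # single ``op_t`` on the original input. When the folded constant is an
    # identity element (1 for mul/pow, 0 for add), the op is dropped and the
    # input node is wired directly to the output.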
    def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable):
        const_0 = self.get_const_value(expr)
        # todo: support more shape
        if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]:
            return expr
        const_1 = self.get_const_value(expr.inputs[0].expr)
        if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]:
            return expr
        inp_node = expr.inputs[0].expr.inputs[0]
        const = op_c(const_0, const_1)
        graph = expr.top_graph

        if (const == 1 and op_t in [operator.pow, operator.mul]) or (
            const == 0 and op_t in [operator.add]
        ):
            graph.replace_node({expr.outputs[0]: inp_node})
            graph.compile()
            return expr

        with expr.top_graph.insert_exprs():
            out_node = op_t(inp_node, const)

        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr
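
    # fold_pow folds the exponents with multiplication because
    # (x ** a) ** b == x ** (a * b); add and mul fold their constants with the
    # same operator they rewrite.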
    def fold_add(self, expr: Expr):
        return self._fold_helper(expr, operator.add, operator.add)

    def fold_mul(self, expr):
        return self._fold_helper(expr, operator.mul, operator.mul)

    def fold_pow(self, expr):
        return self._fold_helper(expr, operator.mul, F.pow)
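
    # F.neg carries no explicit constant operand, so it is treated as a
    # multiplication by -1 when extracting constants for the mul-fold patterns.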
    def get_const_value(self, expr: Expr):
        if is_call_function(expr, F.neg):
            return -1
        if len(expr.inputs) == 2:
            value = get_const_value(expr.inputs[1].expr, None)
            assert value is not None, "the second operand is expected to be a constant"
            return value
        value = expr.const_val[0][-1]
        return value


@register_pass("FuseConvBn")
class FuseConvBn(BackwardPass):
    r"""Fuse BN layers into conv2d."""

    name = "FuseConvBn"
    required_pass = ["AttrToConstant"]
    run_once = True

    def __init__(self):
        super().__init__()
        self.used_name = defaultdict(int)

    def run_transform(self, expr: Expr):
        conv_pat_0 = is_op(M.Conv2d)
        conv_pat_1 = is_op(F.conv2d)
        bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1)
        bn_pat_1 = is_op(F.batch_norm)
        # inp, running_mean, running_var, weight, bias
        bn_inps = (
            conv_pat_0 | conv_pat_1,
            is_const(),
            is_const(),
            is_const(),
            is_const(),
        )
        bn_pat = (
            (bn_pat_1(*bn_inps[:3]))
            | (bn_pat_1(*bn_inps[:4]))
            | (bn_pat_1(*bn_inps))
            | bn_pat_0
        )

        matcher = PatternMatcher()
        if not matcher.match(bn_pat, expr):
            return expr

        matched_exprs = matcher.matched_exprs
        if conv_pat_0 in matched_exprs:
            return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat])
        else:
            return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat])
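
    # Both fold paths rely on the standard conv-BN fusion identity: with
    # scale = gamma / sqrt(var + eps), the fused parameters are
    # w' = w * scale (broadcast over output channels) and
    # b' = (b - mean) * scale + beta, which ``fold_weight_bias`` computes.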
    def fold_convm_bn(self, conv: Expr, bn: Expr):
        mnode, inp_node = conv.inputs[:2]
        self_node = mnode.expr.inputs[0]
        attr_name = conv.inputs[0].expr.name
        graph = conv.top_graph
        if len(mnode.users) > 1:
            self.used_name[mnode.qualname] += 1
            attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
            logger.warning(
                "{} is used {} times and its name will be reset to {}.{}".format(
                    mnode.qualname, len(mnode.users), graph.qualname, attr_name
                )
            )

        conv_module = mnode.owner
        weight, bias = conv_module.weight, conv_module.bias
        mean, var, gamma, beta, eps = self.get_bn_params(bn)
        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
        new_conv = M.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=conv_module.bias is not None,
            conv_mode=conv_module.conv_mode,
            compute_mode=conv_module.compute_mode,
            name=conv_module.name,
        )
        new_conv.weight = Parameter(weight)
        new_conv.bias = Parameter(bias)
        new_conv.training = conv_module.training
        assign_attr(new_conv, self_node.owner, attr_name)
        with graph.insert_exprs(mnode.expr):
            out_node = get_subattr(self_node, attr_name)(inp_node)
        graph.replace_node({bn.outputs[0]: out_node})
        graph.compile()
        out_node.name = conv.outputs[0].name
        return out_node.expr
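
    # The functional case needs no new module: the matched F.conv2d call is
    # re-emitted with its weight/bias arguments replaced by the folded tensors,
    # and the BN output is rewired to the new conv output.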
    def fold_convf_bn(self, conv: Expr, bn: Expr):
        named_args = conv.named_args
        weight = get_const_value(named_args["weight"], named_args["weight"])
        bias = get_const_value(named_args["bias"], named_args["bias"])
        mean, var, gamma, beta, eps = self.get_bn_params(bn)
        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
        named_args["weight"] = weight
        named_args["bias"] = bias
        graph = conv.top_graph
        with graph.insert_exprs():
            out_node = F.conv2d(**named_args)
        graph.replace_node({bn.outputs[0]: out_node})
        graph.compile()
        out_node.name = conv.outputs[0].name
        return out_node.expr
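
    # BN statistics come either from the keyword arguments of a matched
    # F.batch_norm call or, for M.BatchNorm2d, directly from the owning
    # module's attributes.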
    def get_bn_params(self, bn: Expr):
        if is_call_function(bn):
            named_args = bn.named_args
            mean = get_const_value(
                named_args["running_mean"], named_args["running_mean"]
            )
            var = get_const_value(named_args["running_var"], named_args["running_var"])
            gamma = get_const_value(named_args["weight"], named_args["weight"])
            beta = get_const_value(named_args["bias"], named_args["bias"])
            eps = named_args["eps"]
            return mean, var, gamma, beta, eps
        else:
            bn_module = bn.inputs[0].owner
            mean = bn_module.running_mean
            var = bn_module.running_var
            gamma = bn_module.weight
            beta = bn_module.bias
            eps = bn_module.eps
            return mean, var, gamma, beta, eps
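

# A minimal usage sketch (an assumption about the surrounding API, not part of
# this file): passes registered via ``register_pass`` are normally applied to a
# traced module through the traced_module optimization entry point, e.g.
#
#     import megengine.functional as F
#     import megengine.module as M
#     import megengine.traced_module as tm
#
#     net = M.Sequential(M.Conv2d(3, 8, 3), M.BatchNorm2d(8))
#     traced = tm.trace_module(net, F.zeros((1, 3, 16, 16)))
#     fused = tm.optimize(traced, "FuseConvBn", "FuseAddMul")  # hypothetical call
#
# Check the exact ``optimize`` signature and pass names against the installed
# MegEngine version.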