
fuse_pass.py 8.4 kB

import operator
from collections import defaultdict
from typing import Any, Callable, List

from ... import functional as F
from ... import module as M
from ...logger import get_logger
from ...tensor import Parameter, Tensor
from ...utils.bn_fusion import fold_weight_bias
from ..expr import Expr, is_call_function
from ..utils import assign_attr, get_subattr
from .matcher import PatternMatcher
from .pass_base import BackwardPass, register_pass
from .pattern import ExprPattern, any_node, is_const, is_op, is_var
from .utils import get_const_value, register_obj

logger = get_logger(__name__)
  16. @register_pass("FuseAddMul")
  17. class FuseAddMul(BackwardPass):
  18. """Fold adjacent const add or mul binary operations.
  19. For example, the following code
  20. .. code-block::
  21. x = x + 1
  22. x = 2 + x
  23. x = x * 4
  24. x = x * 0.25
  25. will be changed to
  26. .. code-block::
  27. x = x + 3
  28. """
  29. name = "FuseAddMul"
  30. required_pass = ["NormElemWise"]
  31. run_once = False

    def __init__(self,):
        super().__init__()

        def _make_pattern(op_0, op_1) -> ExprPattern:
            x = is_var().check_users(False)
            if op_0 not in [operator.add, operator.mul]:
                op_0 = is_op(op_0)
            if op_1 not in [operator.add, operator.mul]:
                op_1 = is_op(op_1)
            pattern = op_0(x, is_const()) | op_0(x, "*")
            pattern = op_1(pattern, is_const()) | op_1(pattern, "*")
            return pattern

        self.pattern_dict = {}

        for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],):
            self.pattern_dict[_make_pattern(op, op)] = func

        for op_0 in [F.neg, operator.mul]:
            for op_1 in [F.neg, operator.mul]:
                self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul
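
    # Handler registration built above:
    #   add followed by add   -> fold_add  (the two constants are summed)
    #   pow followed by pow   -> fold_pow  (the two exponents are multiplied)
    #   any mix of neg / mul  -> fold_mul  (neg is read as a multiply by -1,
    #                                       see get_const_value below)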

    def run_transform(self, expr: Expr):
        matcher = PatternMatcher()
        for pattern, func in self.pattern_dict.items():
            res = matcher.match(pattern, expr)
            if res:
                break
        if not res:
            return expr
        return func(expr)

    def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable):
        const_0 = self.get_const_value(expr)
        # TODO: support more shapes
        if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]:
            return expr
        const_1 = self.get_const_value(expr.inputs[0].expr)
        if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]:
            return expr
        inp_node = expr.inputs[0].expr.inputs[0]
        const = op_c(const_0, const_1)
        graph = expr.top_graph
        if (const == 1 and op_t in [operator.pow, operator.mul]) or (
            const == 0 and op_t in [operator.add]
        ):
            graph.replace_node({expr.outputs[0]: inp_node})
            graph.compile()
            return expr
        with expr.top_graph.insert_exprs():
            out_node = op_t(inp_node, const)
        graph.replace_node({expr.outputs[0]: out_node})
        graph.compile()
        return out_node.expr

    def fold_add(self, expr: Expr):
        return self._fold_helper(expr, operator.add, operator.add)

    def fold_mul(self, expr):
        return self._fold_helper(expr, operator.mul, operator.mul)

    def fold_pow(self, expr):
        return self._fold_helper(expr, operator.mul, F.pow)

    def get_const_value(self, expr: Expr):
        if is_call_function(expr, F.neg):
            return -1
        if len(expr.inputs) == 2:
            value = get_const_value(expr.inputs[1].expr, None)
            assert value is not None, "the second input is expected to be a constant"
            return value
        value = expr.const_val[0][-1]
        return value
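
# Worked example (illustrative): for the chain ``y = x * 4`` followed by
# ``z = y * 0.25``, ``_fold_helper`` sees const_0 = 0.25 (from the outer
# multiply) and const_1 = 4 (from the inner one), computes
# const = 0.25 * 4 = 1.0, and since a multiplicative constant of 1 is a
# no-op it rewires ``z`` directly to the original input node instead of
# inserting a new multiply.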
  95. @register_pass("FuseConvBn")
  96. class FuseConvBn(BackwardPass):
  97. r"""Fuse BN layers into conv2d."""
  98. name = "FuseConvBn"
  99. required_pass = ["AttrToConstant"]
  100. run_once = True
  101. def __init__(self):
  102. super().__init__()
  103. self.used_name = defaultdict(int)
  104. def run_transform(self, expr: Expr):
  105. conv_pat_0 = is_op(M.Conv2d)
  106. conv_pat_1 = is_op(F.conv2d)
  107. bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1)
  108. bn_pat_1 = is_op(F.batch_norm)
  109. # inp, running_mean, running_var, weight, bias
  110. bn_inps = (
  111. conv_pat_0 | conv_pat_1,
  112. is_const(),
  113. is_const(),
  114. is_const(),
  115. is_const(),
  116. )
  117. bn_pat = (
  118. (bn_pat_1(*bn_inps[:3]))
  119. | (bn_pat_1(*bn_inps[:4]))
  120. | (bn_pat_1(*bn_inps))
  121. | bn_pat_0
  122. )
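        # bn_pat accepts either an F.batch_norm call taking 3, 4 or 5 of the
        # arguments listed in bn_inps (inp, running_mean, running_var, weight,
        # bias) or an M.BatchNorm2d module call, in both cases fed directly by
        # an M.Conv2d module call or an F.conv2d call.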

        matcher = PatternMatcher()
        if not matcher.match(bn_pat, expr):
            return expr

        matched_exprs = matcher.matched_exprs
        if conv_pat_0 in matched_exprs:
            return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat])
        else:
            return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat])
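
    # fold_convm_bn handles the module form (M.Conv2d): it creates a new
    # M.Conv2d with the folded weight/bias and re-assigns it as an attribute
    # of the owning module.  fold_convf_bn handles the functional form
    # (F.conv2d): it only rewrites the call's weight and bias arguments.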

    def fold_convm_bn(self, conv: Expr, bn: Expr):
        mnode, inp_node = conv.inputs[:2]
        self_node = mnode.expr.inputs[0]
        attr_name = conv.inputs[0].expr.name
        graph = conv.top_graph
        if len(mnode.users) > 1:
            self.used_name[mnode.qualname] += 1
            attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
            logger.warning(
                "{} is used {} times and its name will be reset to {}.{}".format(
                    mnode.qualname, len(mnode.users), graph.qualname, attr_name
                )
            )

        conv_module = mnode.owner
        weight, bias = conv_module.weight, conv_module.bias
        mean, var, gamma, beta, eps = self.get_bn_params(bn)

        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
        new_conv = M.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=conv_module.bias is not None,
            conv_mode=conv_module.conv_mode,
            compute_mode=conv_module.compute_mode,
            name=conv_module.name,
        )
        new_conv.weight = Parameter(weight)
        new_conv.bias = Parameter(bias)
        new_conv.training = conv_module.training
        assign_attr(new_conv, self_node.owner, attr_name)
        with graph.insert_exprs(mnode.expr):
            out_node = get_subattr(self_node, attr_name)(inp_node)
        graph.replace_node({bn.outputs[0]: out_node})
        graph.compile()
        out_node.name = conv.outputs[0].name
        return out_node.expr

    def fold_convf_bn(self, conv: Expr, bn: Expr):
        named_args = conv.named_args
        weight = get_const_value(named_args["weight"], named_args["weight"])
        bias = get_const_value(named_args["bias"], named_args["bias"])
        mean, var, gamma, beta, eps = self.get_bn_params(bn)

        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
        named_args["weight"] = weight
        named_args["bias"] = bias
        graph = conv.top_graph
        with graph.insert_exprs():
            out_node = F.conv2d(**named_args)
        graph.replace_node({bn.outputs[0]: out_node})
        graph.compile()
        out_node.name = conv.outputs[0].name
        return out_node.expr

    def get_bn_params(self, bn: Expr):
        if is_call_function(bn):
            named_args = bn.named_args
            mean = get_const_value(
                named_args["running_mean"], named_args["running_mean"]
            )
            var = get_const_value(named_args["running_var"], named_args["running_var"])
            gamma = get_const_value(named_args["weight"], named_args["weight"])
            beta = get_const_value(named_args["bias"], named_args["bias"])
            eps = named_args["eps"]
            return mean, var, gamma, beta, eps
        else:
            bn_module = bn.inputs[0].owner
            mean = bn_module.running_mean
            var = bn_module.running_var
            gamma = bn_module.weight
            beta = bn_module.bias
            eps = bn_module.eps
            return mean, var, gamma, beta, eps
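

# Reference sketch (not used by the passes above).  ``fold_weight_bias`` is
# imported from ``...utils.bn_fusion``; the guarded snippet below is only an
# illustrative numpy check of the standard conv + batch-norm folding identity
# that both passes rely on, and is not claimed to mirror the exact
# implementation of ``fold_weight_bias``.
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(0)
    n, ic, oc, hw = 2, 3, 4, 5
    x = rng.standard_normal((n, ic, hw, hw)).astype("float32")
    w = rng.standard_normal((oc, ic, 1, 1)).astype("float32")
    b = rng.standard_normal(oc).astype("float32")
    gamma, beta = rng.standard_normal(oc), rng.standard_normal(oc)
    mean, var, eps = rng.standard_normal(oc), rng.random(oc) + 0.5, 1e-5

    def _per_channel(v):
        return v.reshape(1, -1, 1, 1)

    def conv1x1(inp, weight, bias):
        # 1x1 convolution expressed as a per-pixel matmul.
        out = np.einsum("nihw,oi->nohw", inp, weight[:, :, 0, 0])
        return out + _per_channel(bias)

    def batch_norm(y):
        # Inference-mode batch norm with the per-channel stats above.
        scale = _per_channel(gamma) / np.sqrt(_per_channel(var) + eps)
        return scale * (y - _per_channel(mean)) + _per_channel(beta)

    # Standard folding: push the BN affine transform into the conv parameters.
    scale = gamma / np.sqrt(var + eps)
    w_fold = w * scale.reshape(-1, 1, 1, 1)
    b_fold = beta + (b - mean) * scale

    assert np.allclose(
        batch_norm(conv1x1(x, w, b)), conv1x1(x, w_fold, b_fold), atol=1e-4
    )
    print("conv + bn folding identity holds")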