- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import operator
- from collections import defaultdict
- from typing import Any, Callable, List
-
- from ... import functional as F
- from ... import module as M
- from ...logger import get_logger
- from ...tensor import Parameter, Tensor
- from ...utils.bn_fusion import fold_weight_bias
- from ..expr import Expr, is_call_function
- from ..utils import assign_attr, get_subattr
- from .matcher import PatternMatcher
- from .pass_base import BackwardPass, register_pass
- from .pattern import ExprPattern, any_node, is_const, is_op, is_var
- from .utils import get_const_value, register_obj
-
- logger = get_logger(__name__)
-
-
- @register_pass("FuseAddMul")
- class FuseAddMul(BackwardPass):
- """Fold adjacent const add or mul binary operations.
-
- For example, the following code
-
- .. code-block::
-
- x = x + 1
- x = 2 + x
- x = x * 4
- x = x * 0.25
-
- will be changed to
-
- .. code-block::
-
- x = x + 3
- """
-
- name = "FuseAddMul"
- required_pass = ["NormElemWise"]
- run_once = False
-
- def __init__(self):
- super().__init__()
-
- def _make_pattern(op_0, op_1) -> ExprPattern:
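- # build a pattern matching op_1(op_0(x, c0), c1), where c0/c1 are const (or arbitrary) operands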
- x = is_var().check_users(False)
- if op_0 not in [operator.add, operator.mul]:
- op_0 = is_op(op_0)
- if op_1 not in [operator.add, operator.mul]:
- op_1 = is_op(op_1)
- pattern = op_0(x, is_const()) | op_0(x, "*")
- pattern = op_1(pattern, is_const()) | op_1(pattern, "*")
- return pattern
-
- self.pattern_dict = {}
-
- for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],):
- self.pattern_dict[_make_pattern(op, op)] = func
-
- for op_0 in [F.neg, operator.mul]:
- for op_1 in [F.neg, operator.mul]:
- self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul
-
- def run_transform(self, expr: Expr):
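- # try each registered pattern; rewrite the matched expr with its fold function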
- matcher = PatternMatcher()
- for pattern, func in self.pattern_dict.items():
- res = matcher.match(pattern, expr)
- if res:
- break
- if not res:
- return expr
- return func(expr)
-
- def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable):
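- # op_c folds the two matched constants into one; op_t rebuilds the remaining single op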
- const_0 = self.get_const_value(expr)
- # todo: support more shape
- if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]:
- return expr
-
- const_1 = self.get_const_value(expr.inputs[0].expr)
- if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]:
- return expr
-
- inp_node = expr.inputs[0].expr.inputs[0]
- const = op_c(const_0, const_1)
- graph = expr.top_graph
-
- # the folded constant is the identity element (x * 1, x ** 1 or x + 0):
- # drop the op entirely and reconnect its input
- if (const == 1 and op_t in [F.pow, operator.mul]) or (
- const == 0 and op_t in [operator.add]
- ):
- graph.replace_node({expr.outputs[0]: inp_node})
- graph.compile()
- return expr
-
- with expr.top_graph.insert_exprs():
- out_node = op_t(inp_node, const)
- graph.replace_node({expr.outputs[0]: out_node})
- graph.compile()
- return out_node.expr
-
- def fold_add(self, expr: Expr):
- return self._fold_helper(expr, operator.add, operator.add)
-
- def fold_mul(self, expr):
- return self._fold_helper(expr, operator.mul, operator.mul)
-
- def fold_pow(self, expr):
- return self._fold_helper(expr, operator.mul, F.pow)
-
- def get_const_value(self, expr: Expr):
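- # F.neg contributes a constant factor of -1; otherwise read the const second operand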
- if is_call_function(expr, F.neg):
- return -1
- if len(expr.inputs) == 2:
- value = get_const_value(expr.inputs[1].expr, None)
- assert value is not None, "expected a const second operand"
- return value
- value = expr.const_val[0][-1]
- return value
-
-
- @register_pass("FuseConvBn")
- class FuseConvBn(BackwardPass):
- r"""Fuse BN layers into conv2d."""
- name = "FuseConvBn"
- required_pass = ["AttrToConstant"]
- run_once = True
-
- def __init__(self):
- super().__init__()
- self.used_name = defaultdict(int)
-
- def run_transform(self, expr: Expr):
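- # match a conv (module or functional) whose result feeds a batch_norm (module or functional)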
- conv_pat_0 = is_op(M.Conv2d)
- conv_pat_1 = is_op(F.conv2d)
- bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1)
- bn_pat_1 = is_op(F.batch_norm)
- # inp, running_mean, running_var, weight, bias
- bn_inps = (
- conv_pat_0 | conv_pat_1,
- is_const(),
- is_const(),
- is_const(),
- is_const(),
- )
- bn_pat = (
- (bn_pat_1(*bn_inps[:3]))
- | (bn_pat_1(*bn_inps[:4]))
- | (bn_pat_1(*bn_inps))
- | bn_pat_0
- )
-
- matcher = PatternMatcher()
- if not matcher.match(bn_pat, expr):
- return expr
-
- matched_exprs = matcher.matched_exprs
- if conv_pat_0 in matched_exprs:
- return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat])
- else:
- return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat])
-
- def fold_convm_bn(self, conv: Expr, bn: Expr):
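- # fold the BN statistics into a new fused M.Conv2d, attach it to the owner module and rewire the graph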
- mnode, inp_node = conv.inputs[:2]
- self_node = mnode.expr.inputs[0]
- attr_name = conv.inputs[0].expr.name
- graph = conv.top_graph
- if len(mnode.users) > 1:
- self.used_name[mnode.qualname] += 1
- attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
- logger.warning(
- "{} is used {} times and its name will be reset to {}.{}".format(
- mnode.qualname, len(mnode.users), graph.qualname, attr_name
- )
- )
-
- conv_module = mnode.owner
- weight, bias = conv_module.weight, conv_module.bias
- mean, var, gamma, beta, eps = self.get_bn_params(bn)
- weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
- new_conv = M.Conv2d(
- in_channels=conv_module.in_channels,
- out_channels=conv_module.out_channels,
- kernel_size=conv_module.kernel_size,
- stride=conv_module.stride,
- padding=conv_module.padding,
- dilation=conv_module.dilation,
- groups=conv_module.groups,
- bias=conv_module.bias is not None,
- conv_mode=conv_module.conv_mode,
- compute_mode=conv_module.compute_mode,
- name=conv_module.name,
- )
- new_conv.weight = Parameter(weight)
- new_conv.bias = Parameter(bias)
- new_conv.training = conv_module.training
- assign_attr(new_conv, self_node.owner, attr_name)
- with graph.insert_exprs(mnode.expr):
- out_node = get_subattr(self_node, attr_name)(inp_node)
- graph.replace_node({bn.outputs[0]: out_node})
- graph.compile()
- out_node.name = conv.outputs[0].name
- return out_node.expr
-
- def fold_convf_bn(self, conv: Expr, bn: Expr):
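- # fold the BN statistics directly into the weight/bias arguments of the functional conv2d call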
- named_args = conv.named_args
- weight = get_const_value(named_args["weight"], named_args["weight"])
- bias = get_const_value(named_args["bias"], named_args["bias"])
- mean, var, gamma, beta, eps = self.get_bn_params(bn)
- weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
- named_args["weight"] = weight
- named_args["bias"] = bias
- graph = conv.top_graph
- with graph.insert_exprs():
- out_node = F.conv2d(**named_args)
- graph.replace_node({bn.outputs[0]: out_node})
- graph.compile()
- out_node.name = conv.outputs[0].name
- return out_node.expr
-
- def get_bn_params(self, bn: Expr):
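- # return (mean, var, gamma, beta, eps) from either an F.batch_norm call or a BatchNorm2d module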
- if is_call_function(bn):
- named_args = bn.named_args
- mean = get_const_value(
- named_args["running_mean"], named_args["running_mean"]
- )
- var = get_const_value(named_args["running_var"], named_args["running_var"])
- gamma = get_const_value(named_args["weight"], named_args["weight"])
- beta = get_const_value(named_args["bias"], named_args["bias"])
- eps = named_args["eps"]
- return mean, var, gamma, beta, eps
- else:
- bn_module = bn.inputs[0].owner
- mean = bn_module.running_mean
- var = bn_module.running_var
- gamma = bn_module.weight
- beta = bn_module.bias
- eps = bn_module.eps
- return mean, var, gamma, beta, eps