From 6c692b26a7834fe832aeb07cd6414b3cac49d040 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 26 Nov 2021 15:44:11 +0800
Subject: [PATCH] feat(mge/traced_module): add some fuse passes

GitOrigin-RevId: 065f9df32eaead53544989c826910f8c326ba738
---
 .../megengine/traced_module/_passes/const_pass.py | 183 +++++++
 .../traced_module/_passes/fold_scale_pass.py | 298 +++++++++++++++++++++
 .../megengine/traced_module/_passes/fuse_pass.py | 248 +++++++++++++++++
 .../megengine/traced_module/_passes/pass_base.py | 190 +++++++
 imperative/python/megengine/traced_module/expr.py | 28 +-
 imperative/python/megengine/traced_module/node.py | 1 +
 imperative/python/megengine/traced_module/utils.py | 24 ++
 imperative/python/megengine/utils/bn_fusion.py | 86 ++++++
 .../test/unit/traced_module/test_qat_module.py | 15 +-
 9 files changed, 1050 insertions(+), 23 deletions(-)
 create mode 100644 imperative/python/megengine/traced_module/_passes/const_pass.py
 create mode 100644 imperative/python/megengine/traced_module/_passes/fold_scale_pass.py
 create mode 100644 imperative/python/megengine/traced_module/_passes/fuse_pass.py
 create mode 100644 imperative/python/megengine/traced_module/_passes/pass_base.py
 create mode 100644 imperative/python/megengine/utils/bn_fusion.py

diff --git a/imperative/python/megengine/traced_module/_passes/const_pass.py b/imperative/python/megengine/traced_module/_passes/const_pass.py
new file mode 100644
index 00000000..143a704c
--- /dev/null
+++ b/imperative/python/megengine/traced_module/_passes/const_pass.py
@@ -0,0 +1,183 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ... import functional as F
+from ... 
import module as M +from ...core.ops.builtin import GetVarShape +from ...logger import get_logger +from ...tensor import Tensor +from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr +from ..node import Node, TensorNode +from .matcher import PatternMatcher +from .pass_base import BackwardPass, ForwardPass, register_pass +from .pattern import is_op +from .utils import get_const_value + +logger = get_logger(__name__) + + +@register_pass("AttrToConstant") +class AttrToConstant(BackwardPass): + r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" + name = "AttrToConstant" + run_once = True + + def run_transform(self, expr: Expr): + if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)): + return expr + graph = expr.top_graph + value = get_const_value(expr) + orig_node = expr.outputs[0] + name = orig_node.name + with graph.insert_exprs(expr): + const_node = Constant.make(value, name=name) + graph.replace_node({orig_node: const_node}) + graph.compile() + name = orig_node.name + return const_node.expr + + +@register_pass("FixInputShape") +class FixInputShape(BackwardPass): + name = "FixInputShape" + run_once = True + + def run_transform(self, expr: Expr): + if not is_apply_def(expr, GetVarShape): + return expr + shape = Tensor(expr.inputs[0].shape, dtype="int32") + graph = expr.top_graph + with graph.insert_exprs(expr): + const_shape = Constant.make(shape) + graph.replace_node({expr.outputs[0]: const_shape}) + graph.compile() + const_shape.name = expr.outputs[0].name + return const_shape.expr + + +@register_pass("FlodConstant") +class FlodConstant(ForwardPass): + r"""Constant folding.""" + name = "FlodConstant" + required_pass = ["AttrToConstant"] + run_once = False + + def run_transform(self, expr: Expr): + if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs): + return expr + const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] + graph = expr.top_graph + with graph.insert_exprs(expr): + const_node = Constant.make(const_var) + graph.replace_node({expr.outputs[0]: const_node}) + graph.compile() + const_node.name = expr.outputs[0].name + return const_node.expr + + +@register_pass("NormElemWise") +class NormElemWise(BackwardPass): + r"""Transform add/sub or mul/div expr to add-only or mul-only chains. + + For example, the following code + + .. code-block:: + + b = 1 - a + c = 2 * b + d = 1 / c + + will be changed to + + .. 
code-block::
+
+        a1 = F.neg(a)
+        b = a1 + 1
+        c = b * 2
+        d = F.pow(c, -1)
+    """
+    name = "NormElemWise"
+    required_pass = ["FlodConstant"]
+    run_once = False
+
+    def __init__(self,):
+        super().__init__()
+        self.pattern = is_op(F.add)
+        for op in [F.sub, F.mul, F.div]:
+            self.pattern |= is_op(op)
+        for op in ["__add__", "__iadd__", "__radd__"]:
+            self.pattern |= is_op(op)
+        for op in ["__sub__", "__isub__", "__rsub__"]:
+            self.pattern |= is_op(op)
+        for op in ["__mul__", "__imul__", "__rmul__"]:
+            self.pattern |= is_op(op)
+        for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
+            self.pattern |= is_op(op)
+
+    def run_transform(self, expr: Expr):
+
+        matcher = PatternMatcher()
+        if not matcher.match(self.pattern, expr):
+            return expr
+
+        pattern = matcher.matched_patterns[0]
+        target = pattern.target
+        coef, left_node, right_node = 1, None, None
+        if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
+            left_node = expr.inputs[0]
+            right_node = expr.const_val[0][-1]
+            if target in ["__rsub__", "__rtruediv__"]:
+                coef = -1
+            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
+                coef = -1
+        elif len(expr.inputs) == 2 and (
+            target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
+        ):
+            left_node, right_node = expr.inputs
+            if target in ["__rsub__", "__rtruediv__"]:
+                left_node, right_node = right_node, left_node
+            if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
+                left_node, right_node = right_node, left_node
+            if is_constant(left_node.expr):
+                left_node, right_node = right_node, left_node
+                coef = -1
+
+        if left_node is None:
+            return expr
+
+        if isinstance(right_node, TensorNode):
+            right_node = get_const_value(right_node.expr, right_node)
+
+        graph = expr.top_graph
+        with graph.insert_exprs():
+            if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
+                out_node = left_node * right_node
+            elif target in ["__add__", "__iadd__", "__radd__", F.add]:
+                out_node = left_node + right_node
+            elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
+                if coef == -1:
+                    left_node = F.neg(left_node)
+                else:
+                    if isinstance(right_node, TensorNode):
+                        right_node = F.neg(right_node)
+                    else:
+                        right_node = -1 * right_node
+                out_node = left_node + right_node
+            elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
+                if coef == -1:
+                    left_node = F.pow(left_node, -1)
+                else:
+                    if isinstance(right_node, TensorNode):
+                        right_node = F.pow(right_node, -1)
+                    else:
+                        right_node = 1 / right_node
+                out_node = left_node * right_node
+
+        graph.replace_node({expr.outputs[0]: out_node})
+        graph.compile()
+        return out_node.expr
diff --git a/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py b/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py
new file mode 100644
index 00000000..ab3e31a2
--- /dev/null
+++ b/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py
@@ -0,0 +1,298 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from typing import Any, Dict, List, Set
+
+from ... import functional as F
+from ... import module as M
+from ...core.ops.builtin import GetVarShape
+from ...logger import get_logger
+from ...tensor import Parameter, Tensor
+from ..expr import (
+    Expr,
+    is_apply_def,
+    is_call_function,
+    is_call_module,
+    is_call_tensor_method,
+    is_constant,
+    is_getattr,
+)
+from ..traced_module import InternalGraph
+from ..utils import assign_attr, get_subattr
+from .matcher import PatternMatcher
+from .pass_base import BackwardPass, register_pass
+from .pattern import is_const, is_op, is_var
+from .utils import get_const_value
+
+logger = get_logger(__name__)
+
+
+@register_pass("BackwardFoldScale")
+class BackwardFoldScale(BackwardPass):
+    r"""Backward fold const scaling into weights of conv2d.
+
+    For example, the following code
+
+    .. code-block::
+
+        x = conv(x, w, b)
+        x = relu(x)
+        x1 = x + 3
+        x2 = x + 4
+        y = (x1 + x2) * 3
+
+    will be changed to
+
+    .. code-block::
+
+        x = conv(x, w * 3, b * 3)
+        x = relu(x)
+        x1 = x + 9
+        x2 = x + 12
+        y = x1 + x2
+
+    """
+    name = "BackwardFoldScale"
+    required_pass = ["AttrToConstant", "NormElemWise"]
+    run_once = True
+
+    def __init__(self):
+        super().__init__()
+        # todo: support more axes
+        self.scale_message = OrderedDict()
+        self.used_names = defaultdict(int)
+
+    def run_transform(self, expr: Expr) -> Expr:
+        if expr not in self.scale_message:
+            return expr
+
+        var = is_var().check_users(False)
+        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
+        add_const_pattern = var + is_const() | var + "*"
+        conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d)
+        pattern = conv_pattern | add_const_pattern | mul_const_pattern
+        matcher = PatternMatcher()
+
+        if not matcher.match(pattern, expr):
+            return expr
+        matched_exprs = matcher.matched_exprs
+
+        if conv_pattern in matched_exprs:
+            return self.fold_conv_mul(expr)
+
+        if mul_const_pattern in matched_exprs:
+            return self.fold_mul(expr)
+
+        if add_const_pattern in matched_exprs:
+            return self.fold_add_mul(expr)
+
+        return expr
+
+    def fold_add_mul(self, expr: Expr):
+        if self.scale_message[expr] is None:
+            return expr
+        scale = self.scale_message[expr]
+        if len(expr.inputs) == 1:
+            const = expr.const_val[0][-1]
+        else:
+            const = get_const_value(expr.inputs[1])
+
+        const = const * scale
+        inp_node = expr.inputs[0]
+        graph = expr.top_graph
+        with graph.insert_exprs():
+            add_node = inp_node + const
+
+        graph.replace_node({expr.outputs[0]: add_node})
+        graph.compile()
+        add_node.name = expr.outputs[0].name
+        return add_node.expr
+
+    def fold_mul(self, expr: Expr):
+        if self.scale_message[expr] is None:
+            return expr
+        graph = expr.top_graph
+        graph.replace_node({expr.outputs[0]: expr.inputs[0]})
+        graph.compile()
+        return expr
+
+    def fold_conv_mul(self, expr: Expr):
+        graph = expr.top_graph
+        scale = self.scale_message[expr]
+
+        if scale is None:
+            return expr
+
+        if is_call_function(expr, F.conv2d):
+            named_args = expr.named_args
+            weight = get_const_value(named_args["weight"], named_args["weight"]) * scale
+            bias = get_const_value(named_args["bias"], named_args["bias"]) * scale
+            named_args["weight"] = weight
+            named_args["bias"] = bias
+            with graph.insert_exprs():
+                out_node = F.conv2d(**named_args)
+            graph.replace_node({expr.outputs[0]: out_node})
+            graph.compile()
+            out_node.name = expr.outputs[0].name
+            return out_node.expr
+        else:
+            mnode = expr.inputs[0]
+            attr_name = expr.inputs[0].expr.name
+            graph = expr.top_graph
+            if len(mnode.users) > 1:
+                self.used_names[mnode.qualname] += 1
+                attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname])
+                logger.warning(
+                    "{} is used {} times and its name will be reset to {}.{}".format(
+                        mnode.qualname, len(mnode.users), graph.qualname, attr_name
+                    )
+                )
+            conv_module = mnode.owner
+            if len(mnode.users) > 1:
+                conv_module = deepcopy(conv_module)
+                conv_module._name = None
+            conv_module.weight = Parameter(conv_module.weight * scale)
+            if conv_module.bias is not None:
+                conv_module.bias = Parameter(conv_module.bias * scale)
+
+            if len(mnode.users) > 1:
+                self_node = mnode.expr.inputs[0]
+                assign_attr(conv_module, self_node.owner, attr_name)
+                with graph.insert_exprs(mnode.expr):
+                    new_conv_node = get_subattr(self_node, attr_name)
+                expr.replace_inputs({mnode: new_conv_node})
+            return expr
+
+    def reset_expr_message_to_none(
+        self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr],
+    ):
+        if expr in skip_exprs:
+            return
+        scale_message[expr] = None
+        if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
+            return
+        for out_node in expr.outputs:
+            for user in out_node.users:
+                if user in scale_message:
+                    self.reset_expr_message_to_none(user, scale_message, skip_exprs)
+
+    def before_visit_graph(self, graph: InternalGraph):
+        var = is_var().check_users(False)
+        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
+        relu_pattern = (
+            is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU)
+        )
+
+        # The parameters of conv must be const; dynamic conv is not supported
+        conv_pattern = (
+            is_op(F.conv2d)(var, is_const(), is_const())
+            | is_op(F.conv2d)(var, is_const())
+            | is_op(M.Conv2d)
+        )
+
+        pattern = mul_const_pattern | relu_pattern | conv_pattern
+        for op in [
+            "__add__",
+            F.reshape,
+            "reshape",
+            F.transpose,
+            "transpose",
+            F.min,
+            "min",
+            F.max,
+            "max",
+            F.max_pool2d,
+            M.MaxPool2d,
+            F.avg_pool2d,
+            M.AvgPool2d,
+            F.adaptive_avg_pool2d,
+            M.AdaptiveAvgPool2d,
+            F.adaptive_max_pool2d,
+            M.AdaptiveMaxPool2d,
+            F.expand_dims,
+            F.concat,
+            "__getitem__",
+        ]:
+            pattern |= is_op(op)
+
+        matcher = PatternMatcher()
+
+        scale_message = OrderedDict()
+        mem_conv_scale_message = OrderedDict()
+        skip_exprs = self.init_skip_exprs(graph)
+        for expr in reversed(graph._exprs):
+            if expr in skip_exprs:
+                continue
+
+            if len(expr.outputs) > 1 or not matcher.match(pattern, expr):
+                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
+                if is_call_function(expr, F.conv2d):
+                    for user in expr.outputs[0].users:
+                        self.reset_expr_message_to_none(user, scale_message, skip_exprs)
+                continue
+
+            matched_exprs = matcher.matched_exprs
+
+            const = None
+            if mul_const_pattern in matched_exprs:
+                if is_call_function(expr, F.neg):
+                    const = -1
+                elif len(expr.inputs) == 1:
+                    const = expr.const_val[0][-1]
+                else:
+                    const = get_const_value(expr.inputs[1])
+
+            if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]:
+                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
+                continue
+
+            users_const = [
+                scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs
+            ]
+
+            if len(users_const) == 0:
+                scale_message[expr] = const
+                continue
+
+            if any(c is None or c != users_const[0] for c in users_const):
+                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
+                scale_message[expr] = const
+                continue
+
+            const = 1 if const is None else const
+            const = const * users_const[0]
+            if relu_pattern in matched_exprs and const < 0:
+                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
+                continue
+
+            if conv_pattern in matched_exprs:
+                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
+                mem_conv_scale_message[expr] = const
+                continue
+
+            scale_message[expr] = const
+
+        self.scale_message.update(scale_message)
+        self.scale_message.update(mem_conv_scale_message)
+
+    def init_skip_exprs(self, graph: InternalGraph):
+        skip_exprs = set()
+        for expr in graph._exprs:
+            if is_apply_def(expr, GetVarShape):
+                skip_exprs.add(expr)
+            elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs:
+                skip_exprs.add(expr)
+            elif is_getattr(expr):
+                skip_exprs.add(expr)
+            elif is_constant(expr):
+                skip_exprs.add(expr)
+            elif all(n.expr in skip_exprs for n in expr.inputs):
+                skip_exprs.add(expr)
+        return skip_exprs
diff --git a/imperative/python/megengine/traced_module/_passes/fuse_pass.py b/imperative/python/megengine/traced_module/_passes/fuse_pass.py
new file mode 100644
index 00000000..59822572
--- /dev/null
+++ b/imperative/python/megengine/traced_module/_passes/fuse_pass.py
@@ -0,0 +1,248 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import operator
+from collections import defaultdict
+from typing import Any, Callable, List
+
+from ... import functional as F
+from ... import module as M
+from ...logger import get_logger
+from ...tensor import Parameter, Tensor
+from ...utils.bn_fusion import fold_weight_bias
+from ..expr import Expr, is_call_function
+from ..utils import assign_attr, get_subattr
+from .matcher import PatternMatcher
+from .pass_base import BackwardPass, register_pass
+from .pattern import ExprPattern, any_node, is_const, is_op, is_var
+from .utils import get_const_value, register_obj
+
+logger = get_logger(__name__)
+
+
+@register_pass("FuseAddMul")
+class FuseAddMul(BackwardPass):
+    """Fold adjacent const add or mul binary operations.
+
+    For example, the following code
+
+    .. code-block::
+
+        x = x + 1
+        x = 2 + x
+        x = x * 4
+        x = x * 0.25
+
+    will be changed to
+
+    .. 
code-block:: + + x = x + 3 + """ + + name = "FuseAddMul" + required_pass = ["NormElemWise"] + run_once = False + + def __init__(self,): + super().__init__() + + def _make_pattern(op_0, op_1) -> ExprPattern: + x = is_var().check_users(False) + if op_0 not in [operator.add, operator.mul]: + op_0 = is_op(op_0) + if op_1 not in [operator.add, operator.mul]: + op_1 = is_op(op_1) + pattern = op_0(x, is_const()) | op_0(x, "*") + pattern = op_1(pattern, is_const()) | op_1(pattern, "*") + return pattern + + self.pattern_dict = {} + + for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],): + self.pattern_dict[_make_pattern(op, op)] = func + + for op_0 in [F.neg, operator.mul]: + for op_1 in [F.neg, operator.mul]: + self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul + + def run_transform(self, expr: Expr): + matcher = PatternMatcher() + for pattern, func in self.pattern_dict.items(): + res = matcher.match(pattern, expr) + if res: + break + if not res: + return expr + return func(expr) + + def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable): + const_0 = self.get_const_value(expr) + # todo: support more shape + if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]: + return expr + + const_1 = self.get_const_value(expr.inputs[0].expr) + if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]: + return expr + + inp_node = expr.inputs[0].expr.inputs[0] + const = op_c(const_0, const_1) + graph = expr.top_graph + + if (const == 1 and op_t in [operator.pow, operator.mul]) or ( + const == 0 and op_t in [operator.add] + ): + graph.replace_node({expr.outputs[0]: inp_node}) + graph.compile() + return expr + + with expr.top_graph.insert_exprs(): + out_node = op_t(inp_node, const) + graph.replace_node({expr.outputs[0]: out_node}) + graph.compile() + return out_node.expr + + def fold_add(self, expr: Expr): + return self._fold_helper(expr, operator.add, operator.add) + + def fold_mul(self, expr): + return self._fold_helper(expr, operator.mul, operator.mul) + + def fold_pow(self, expr): + return self._fold_helper(expr, operator.mul, F.pow) + + def get_const_value(self, expr: Expr): + if is_call_function(expr, F.neg): + return -1 + if len(expr.inputs) == 2: + value = get_const_value(expr.inputs[1].expr, None) + assert value is not None, " " + return value + value = expr.const_val[0][-1] + return value + + +@register_pass("FuseConvBn") +class FuseConvBn(BackwardPass): + r"""Fuse BN layers into conv2d.""" + name = "FuseConvBn" + required_pass = ["AttrToConstant"] + run_once = True + + def __init__(self): + super().__init__() + self.used_name = defaultdict(int) + + def run_transform(self, expr: Expr): + conv_pat_0 = is_op(M.Conv2d) + conv_pat_1 = is_op(F.conv2d) + bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1) + bn_pat_1 = is_op(F.batch_norm) + # inp, running_mean, running_var, weight, bias + bn_inps = ( + conv_pat_0 | conv_pat_1, + is_const(), + is_const(), + is_const(), + is_const(), + ) + bn_pat = ( + (bn_pat_1(*bn_inps[:3])) + | (bn_pat_1(*bn_inps[:4])) + | (bn_pat_1(*bn_inps)) + | bn_pat_0 + ) + + matcher = PatternMatcher() + if not matcher.match(bn_pat, expr): + return expr + + matched_exprs = matcher.matched_exprs + if conv_pat_0 in matched_exprs: + return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat]) + else: + return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat]) + + def fold_convm_bn(self, conv: Expr, bn: Expr): + mnode, inp_node = conv.inputs[:2] + self_node = 
mnode.expr.inputs[0]
+        attr_name = conv.inputs[0].expr.name
+        graph = conv.top_graph
+        if len(mnode.users) > 1:
+            self.used_name[mnode.qualname] += 1
+            attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
+            logger.warning(
+                "{} is used {} times and its name will be reset to {}.{}".format(
+                    mnode.qualname, len(mnode.users), graph.qualname, attr_name
+                )
+            )
+
+        conv_module = mnode.owner
+        weight, bias = conv_module.weight, conv_module.bias
+        mean, var, gamma, beta, eps = self.get_bn_params(bn)
+        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
+        new_conv = M.Conv2d(
+            in_channels=conv_module.in_channels,
+            out_channels=conv_module.out_channels,
+            kernel_size=conv_module.kernel_size,
+            stride=conv_module.stride,
+            padding=conv_module.padding,
+            dilation=conv_module.dilation,
+            groups=conv_module.groups,
+            bias=conv_module.bias is not None,
+            conv_mode=conv_module.conv_mode,
+            compute_mode=conv_module.compute_mode,
+            name=conv_module.name,
+        )
+        new_conv.weight = Parameter(weight)
+        new_conv.bias = Parameter(bias)
+        new_conv.training = conv_module.training
+        assign_attr(new_conv, self_node.owner, attr_name)
+        with graph.insert_exprs(mnode.expr):
+            out_node = get_subattr(self_node, attr_name)(inp_node)
+        graph.replace_node({bn.outputs[0]: out_node})
+        graph.compile()
+        out_node.name = conv.outputs[0].name
+        return out_node.expr
+
+    def fold_convf_bn(self, conv: Expr, bn: Expr):
+        named_args = conv.named_args
+        weight = get_const_value(named_args["weight"], named_args["weight"])
+        bias = get_const_value(named_args["bias"], named_args["bias"])
+        mean, var, gamma, beta, eps = self.get_bn_params(bn)
+        weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
+        named_args["weight"] = weight
+        named_args["bias"] = bias
+        graph = conv.top_graph
+        with graph.insert_exprs():
+            out_node = F.conv2d(**named_args)
+        graph.replace_node({bn.outputs[0]: out_node})
+        graph.compile()
+        out_node.name = conv.outputs[0].name
+        return out_node.expr
+
+    def get_bn_params(self, bn: Expr):
+        if is_call_function(bn):
+            named_args = bn.named_args
+            mean = get_const_value(
+                named_args["running_mean"], named_args["running_mean"]
+            )
+            var = get_const_value(named_args["running_var"], named_args["running_var"])
+            gamma = get_const_value(named_args["weight"], named_args["weight"])
+            beta = get_const_value(named_args["bias"], named_args["bias"])
+            eps = named_args["eps"]
+            return mean, var, gamma, beta, eps
+        else:
+            bn_module = bn.inputs[0].owner
+            mean = bn_module.running_mean
+            var = bn_module.running_var
+            gamma = bn_module.weight
+            beta = bn_module.bias
+            eps = bn_module.eps
+            return mean, var, gamma, beta, eps
diff --git a/imperative/python/megengine/traced_module/_passes/pass_base.py b/imperative/python/megengine/traced_module/_passes/pass_base.py
new file mode 100644
index 00000000..735d7fae
--- /dev/null
+++ b/imperative/python/megengine/traced_module/_passes/pass_base.py
@@ -0,0 +1,190 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import copy
+from abc import abstractmethod
+from collections import OrderedDict, namedtuple
+from functools import partial
+from re import T
+from typing import Any, Callable, Dict, Iterable, List, Union
+
+from ...logger import get_logger
+from ..expr import Expr
+from ..traced_module import InternalGraph, TracedModule
+from .utils import register_obj
+
+logger = get_logger(__name__)
+
+
+class PassContext:
+    def __init__(
+        self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
+    ):
+        self._disabled_pass = set()
+        self._config = pass_config
+        self._handle = None
+        if disabled_pass:
+            self.add_disabled_pass(disabled_pass)
+
+    def add_disabled_pass(self, passes: Iterable[str]):
+        if isinstance(passes, str):
+            passes = [passes]
+        for pas in passes:
+            self._disabled_pass.add(pas)
+
+    def pass_enabled(self, pas: Union["BasePass", str]):
+        pass_name = pas.name if isinstance(pas, BasePass) else pas
+        return pass_name not in self._disabled_pass
+
+
+_default_context = PassContext()
+
+
+def get_default_pass_context():
+    return _default_context
+
+
+_pass_dict = OrderedDict()
+register_pass = partial(register_obj, _dict=_pass_dict)
+
+
+def get_registered_pass(pass_name: str):
+    pas = _pass_dict.get(pass_name, None)
+    assert (
+        pas is not None
+    ), "{} is not found, please call `register_pass` to register it".format(pass_name)
+    return pas
+
+
+class BasePass:
+    run_once = True  # bool
+    required_pass = []  # Iterable[str]
+    name = ""  # str
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(
+        self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
+    ) -> TracedModule:
+        assert isinstance(pass_ctx, PassContext)
+        return self.apply_optimization(mod, pass_ctx)
+
+    def apply_optimization(
+        self, mod: TracedModule, pass_ctx: PassContext
+    ) -> TracedModule:
+        new_mod = mod
+        for pass_name in self.required_pass + [self.name]:
+            if not pass_ctx.pass_enabled(pass_name):
+                logger.warning(
+                    "Since {} is disabled, {} will be skipped".format(pass_name, self.name)
+                )
+                return mod
+
+        for pass_name in self.required_pass:
+            pass_func = get_registered_pass(pass_name)()
+            new_mod = pass_func(new_mod, pass_ctx)
+
+        iter_num = 1
+        graph_changed = self.visit_graph(new_mod.graph)
+        while not self.run_once and graph_changed:
+            graph_changed = self.visit_graph(new_mod.graph)
+            iter_num += 1
+            if iter_num == 100:
+                break
+        assert iter_num < 100, "{} was run 100 times, please check for pass conflict.".format(self.name)
+ + return new_mod + + @abstractmethod + def visit_graph(self, graph: InternalGraph): + raise NotImplementedError + + def before_visit_graph(self, graph: InternalGraph): + pass + + def run_transform(self, expr: Expr) -> Expr: + return expr + + def __repr__(self) -> str: + return self.name + + +class ForwardPass(BasePass): + def visit_graph(self, graph: InternalGraph): + class Item: + def __init__(self, expr: Expr, child_expanded: bool = False): + self.expr = expr + self.child_expanded = child_expanded + + self.before_visit_graph(graph) + graph_changed = False + queue = [Item(n.expr) for n in graph.outputs] + visited_expr, visited_graph = set(), set() + while queue: + item = queue[-1] + if item.expr in visited_expr: + queue.pop() + elif item.child_expanded: + if item.expr not in graph._exprs: + queue.pop() + continue + new_expr = self.run_transform(item.expr) + if new_expr is not item.expr: + graph_changed = True + assert new_expr not in visited_expr + queue.append(Item(new_expr)) + continue + if ( + hasattr(item.expr, "graph") + and item.expr.graph is not None + and item.expr.graph not in visited_graph + ): + graph_changed |= self.visit_graph(item.expr.graph) + visited_graph.add(item.expr.graph) + visited_expr.add(item.expr) + else: + item.child_expanded = True + for i in item.expr.inputs: + expr = i.expr + if expr not in queue and expr not in visited_expr: + queue.append(Item(expr)) + return graph_changed + + +class BackwardPass(BasePass): + def visit_graph(self, graph: InternalGraph): + self.before_visit_graph(graph) + graph_changed = False + queue = [n.expr for n in graph.outputs] + visited_expr, visited_graph = set(), set() + while queue: + expr = queue.pop() + if expr not in graph._exprs: + continue + new_expr = self.run_transform(expr) + if new_expr is not expr: + graph_changed = True + queue.append(new_expr) + continue + else: + visited_expr.add(expr) + + if ( + hasattr(expr, "graph") + and expr.graph is not None + and expr.graph not in visited_graph + ): + graph_changed |= self.visit_graph(expr.graph) + visited_graph.add(expr.graph) + + for i in expr.inputs: + expr = i.expr + if expr not in queue and expr not in visited_expr: + queue.append(expr) + return graph_changed diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index fcd49dfe..b7e7c077 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -13,7 +13,7 @@ import inspect import re import weakref from importlib import import_module -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union from ..core._imperative_rt import OpDef from ..core._imperative_rt.core2 import Tensor as RawTensor @@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str): return matchd.group(1) -def is_call_module(expr): +def is_call_module(expr, module_cls: Module = None): return ( isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode) and expr.method == "__call__" - ) + ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls)) -def is_call_tensor_method(expr): - return isinstance(expr, CallMethod) and not is_call_module(expr) +def is_call_tensor_method(expr, method: Iterable[str] = None): + if method and isinstance(method, str): + method = (method,) + return ( + isinstance(expr, CallMethod) + and not is_call_module(expr) + and (method is None or any(expr.method == f for f in method)) + ) 
-def is_call_function(expr):
-    return isinstance(expr, CallFunction)
+def is_call_function(expr, func: Iterable[Callable] = None):
+    if func and not isinstance(func, Iterable):
+        func = (func,)
+    return isinstance(expr, CallFunction) and (
+        func is None or any(expr.func == f for f in func)
+    )
 
 
 def is_constant(expr):
@@ -74,8 +84,8 @@ def is_getattr(expr):
     return isinstance(expr, GetAttr)
 
 
-def is_apply_def(expr):
-    return isinstance(expr, Apply)
+def is_apply_def(expr, opdef=None):
+    return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef))
 
 
 def is_input(expr):
diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py
index e6786406..7bc15705 100644
--- a/imperative/python/megengine/traced_module/node.py
+++ b/imperative/python/megengine/traced_module/node.py
@@ -78,6 +78,7 @@ class Node:
                 "The name(%s) is already in use. Please try a different one again."
                 % (new_name)
             )
+        graph._namespace.unassociate_name_with_obj(self)
         self._name = graph._namespace.create_unique_name(new_name, self)
 
     @property
diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py
index 48094d5a..21ccb35c 100644
--- a/imperative/python/megengine/traced_module/utils.py
+++ b/imperative/python/megengine/traced_module/utils.py
@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
 
 from .. import get_logger
 from ..module import Module
+from ..tensor import Parameter, Tensor
 
 logger = get_logger(__name__)
 
@@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping):
 
     def forward(self):
         raise RuntimeError("ModuleList is not callable")
+
+
+def assign_attr(obj: Union[Module, Tensor], module: Module, target: str):
+    *prefix, name = target.split(".")
+    for item in prefix:
+        module = getattr(module, item)
+        if not isinstance(module, Module):
+            raise AttributeError("`{}` is not a Module".format(item))
+    setattr(module, name, obj)
+
+
+def get_subattr(module: Module, target: str):
+    # todo : remove this import
+    from .node import ModuleNode
+
+    if target == "":
+        return module
+    *prefix, name = target.split(".")
+    for item in prefix:
+        module = getattr(module, item)
+        if not isinstance(module, (Module, ModuleNode)):
+            raise AttributeError("`{}` is not a Module".format(item))
+    return getattr(module, name)
diff --git a/imperative/python/megengine/utils/bn_fusion.py b/imperative/python/megengine/utils/bn_fusion.py
new file mode 100644
index 00000000..41f08055
--- /dev/null
+++ b/imperative/python/megengine/utils/bn_fusion.py
@@ -0,0 +1,86 @@
+from copy import deepcopy
+
+from ..functional import ones, sqrt, zeros
+from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
+from ..tensor import Parameter
+
+_MAP_TO_FUSED_MODULE = {
+    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
+    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
+    (Conv2d, BatchNorm2d, False): Conv2d,
+    (Conv2d, BatchNorm2d, True): ConvBn2d,
+    (Conv2d, ReLU): ConvRelu2d,
+}
+
+
+def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
+    # get fold bn conv param
+    kernel_shape = weight.shape
+    if len(kernel_shape) == 5:
+        groups, num_features = kernel_shape[0], kernel_shape[1]
+    else:
+        groups, num_features = 1, kernel_shape[0]
+
+    if gamma is None:
+        gamma = ones((num_features), dtype="float32")
+    gamma = gamma.reshape(1, -1, 1, 1)
+    if beta is None:
+        beta = zeros((num_features), dtype="float32")
+    beta = beta.reshape(1, -1, 1, 1)
+
+    if 
bn_mean is None: + bn_mean = zeros((1, num_features, 1, 1), dtype="float32") + if bn_var is None: + bn_var = ones((1, num_features, 1, 1), dtype="float32") + + if bias is None: + bias = zeros((1, num_features, 1, 1), dtype="float32") + + bn_istd = 1.0 / sqrt(bn_var + eps) + scale_factor = gamma * bn_istd + + if groups == 1: + w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) + else: + w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) + + b_fold = beta + gamma * (bias - bn_mean) * bn_istd + return w_fold, b_fold + + +def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): + module_key = tuple([type(m) for m in [conv, bn, relu] if m]) + if bn: + assert ( + conv.training == bn.training + ), "Conv and BN both must be in the same mode (train or eval)." + assert ( + bn.num_features == conv.out_channels + ), "Output channel of Conv2d must match num_features of BatchNorm2d" + module_key = module_key + (conv.training,) + module = _MAP_TO_FUSED_MODULE[module_key]( + in_channels=conv.in_channels, + out_channels=conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=conv.bias is not None, + conv_mode=conv.conv_mode, + compute_mode=conv.compute_mode, + name=conv.name, + ) + new_conv = module if bn is None or not conv.training else module.conv + weight, bias = conv.weight, conv.bias + if not conv.training and bn is not None: + weight, bias = fold_weight_bias( + weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, + ) + new_conv.weight = Parameter(weight) + if bias is not None: + new_conv.bias = Parameter(bias) + if bn is not None and conv.training: + module.bn = deepcopy(bn) + new_conv.training = conv.training + return module diff --git a/imperative/python/test/unit/traced_module/test_qat_module.py b/imperative/python/test/unit/traced_module/test_qat_module.py index d4011163..6ef8764b 100644 --- a/imperative/python/test/unit/traced_module/test_qat_module.py +++ b/imperative/python/test/unit/traced_module/test_qat_module.py @@ -13,20 +13,7 @@ import megengine.quantization as Q from megengine import Tensor from megengine.module.qat.module import QATModule from megengine.traced_module import TracedModule, trace_module - - -def get_subattr(self: M.Module, name: str): - if name == "": - return self - module_path, _, name = name.rpartition(".") - if module_path == "": - return getattr(self, name) - module_names = module_path.split(".") - for item in module_names: - self = getattr(self, item) - if not isinstance(self, M.Module): - raise AttributeError("`{}` is not an Module".format(item)) - return getattr(self, name) +from megengine.traced_module.utils import get_subattr class MyConvBnRelu2d(M.ConvBnRelu2d):
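
Usage note (editorial, not part of the patch): the sketch below shows one way the passes registered above could be chained on a TracedModule, using only the API introduced in this diff (register_pass/get_registered_pass, PassContext, BasePass.__call__). The import paths mirror the new file locations; the toy network, tensor names, and input shape are made up for illustration, and the sketch assumes the companion matcher/pattern/utils helpers referenced by the new `_passes` modules are present so the pass modules import cleanly.

    # Hypothetical end-to-end sketch; not taken from the patch or its tests.
    import numpy as np

    import megengine.functional as F
    import megengine.module as M
    from megengine import Tensor
    from megengine.traced_module import trace_module
    # Importing the pass modules runs their @register_pass decorators, which is
    # assumed to populate the registry consumed by get_registered_pass.
    from megengine.traced_module._passes import const_pass, fuse_pass  # noqa: F401
    from megengine.traced_module._passes.pass_base import (
        PassContext,
        get_registered_pass,
    )


    class SimpleNet(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 8, 3, padding=1)
            self.bn = M.BatchNorm2d(8)

        def forward(self, x):
            # relu(bn(conv(x))) followed by constant mul/add, so FuseConvBn and
            # FuseAddMul (via its required NormElemWise pass) both have work to do.
            return F.relu(self.bn(self.conv(x))) * 2 + 1


    net = SimpleNet()
    net.eval()  # run in eval mode so the BN statistics are fixed before fusion
    data = Tensor(np.random.randn(1, 3, 16, 16).astype("float32"))
    traced = trace_module(net, data)

    ctx = PassContext()  # individual passes can be disabled via disabled_pass=[...]
    for pass_name in ["FuseConvBn", "FuseAddMul"]:
        # get_registered_pass returns the pass class; instantiate it, then call
        # it with (module, context) as BasePass.__call__ expects.
        traced = get_registered_pass(pass_name)()(traced, ctx)

Each pass pulls in its required_pass dependencies itself (apply_optimization runs them first), so the explicit order above only needs to list the top-level rewrites of interest.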