@@ -0,0 +1,183 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ... import functional as F | |||
from ... import module as M | |||
from ...core.ops.builtin import GetVarShape | |||
from ...logger import get_logger | |||
from ...tensor import Tensor | |||
from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr | |||
from ..node import Node, TensorNode | |||
from .matcher import PatternMatcher | |||
from .pass_base import BackwardPass, ForwardPass, register_pass | |||
from .pattern import is_op | |||
from .utils import get_const_value | |||
logger = get_logger(__name__) | |||
@register_pass("AttrToConstant") | |||
class AttrToConstant(BackwardPass): | |||
r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" | |||
name = "AttrToConstant" | |||
run_once = True | |||
def run_transform(self, expr: Expr): | |||
if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)): | |||
return expr | |||
graph = expr.top_graph | |||
value = get_const_value(expr) | |||
orig_node = expr.outputs[0] | |||
name = orig_node.name | |||
with graph.insert_exprs(expr): | |||
const_node = Constant.make(value, name=name) | |||
graph.replace_node({orig_node: const_node}) | |||
graph.compile() | |||
        const_node.name = name  # restore the original node's name on the new Constant node
return const_node.expr | |||
@register_pass("FixInputShape") | |||
class FixInputShape(BackwardPass): | |||
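    r"""Replace ``GetVarShape`` exprs with a :class:`~.Constant` holding the static shape of their input."""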
name = "FixInputShape" | |||
run_once = True | |||
def run_transform(self, expr: Expr): | |||
if not is_apply_def(expr, GetVarShape): | |||
return expr | |||
shape = Tensor(expr.inputs[0].shape, dtype="int32") | |||
graph = expr.top_graph | |||
with graph.insert_exprs(expr): | |||
const_shape = Constant.make(shape) | |||
graph.replace_node({expr.outputs[0]: const_shape}) | |||
graph.compile() | |||
const_shape.name = expr.outputs[0].name | |||
return const_shape.expr | |||
@register_pass("FlodConstant") | |||
class FlodConstant(ForwardPass): | |||
r"""Constant folding.""" | |||
name = "FlodConstant" | |||
required_pass = ["AttrToConstant"] | |||
run_once = False | |||
def run_transform(self, expr: Expr): | |||
if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs): | |||
return expr | |||
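        # All inputs are constants: evaluate the expr eagerly and wrap the result in a Constant expr.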
const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] | |||
graph = expr.top_graph | |||
with graph.insert_exprs(expr): | |||
const_node = Constant.make(const_var) | |||
graph.replace_node({expr.outputs[0]: const_node}) | |||
graph.compile() | |||
const_node.name = expr.outputs[0].name | |||
return const_node.expr | |||
@register_pass("NormElemWise") | |||
class NormElemWise(BackwardPass): | |||
r"""Transform add/sub or mul/div expr to add-only or mul-only chains. | |||
For example, the following code | |||
.. code-block:: | |||
b = 1 - a | |||
c = 2 * b | |||
d = 1 / c | |||
will be changed to | |||
.. code-block:: | |||
a1 = F.neg(a) | |||
b = a1 + 1 | |||
c = b * 2 | |||
        d = F.pow(c, -1)
""" | |||
name = "NormElemWise" | |||
required_pass = ["FlodConstant"] | |||
run_once = False | |||
def __init__(self,): | |||
super().__init__() | |||
self.pattern = is_op(F.add) | |||
for op in [F.sub, F.mul, F.div]: | |||
self.pattern |= is_op(op) | |||
for op in ["__add__", "__iadd__", "__radd__"]: | |||
self.pattern |= is_op(op) | |||
for op in ["__sub__", "__isub__", "__rsub__"]: | |||
self.pattern |= is_op(op) | |||
for op in ["__mul__", "__imul__", "__rmul__"]: | |||
self.pattern |= is_op(op) | |||
for op in ["__truediv__", "__itruediv__", "__rtruediv__"]: | |||
self.pattern |= is_op(op) | |||
def run_transform(self, expr: Expr): | |||
matcher = PatternMatcher() | |||
if not matcher.match(self.pattern, expr): | |||
return expr | |||
pattern = matcher.matched_patterns[0] | |||
target = pattern.target | |||
cofee, left_node, right_node = 1, None, None | |||
if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]: | |||
left_node = expr.inputs[0] | |||
right_node = expr.const_val[0][-1] | |||
if target in ["__rsub__", "__rtruediv__"]: | |||
cofee = -1 | |||
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||
cofee = -1 | |||
elif len(expr.inputs) == 2 and ( | |||
target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr) | |||
): | |||
left_node, right_node = expr.inputs | |||
if target in ["__rsub__", "__rtruediv__"]: | |||
left_node, right_node = right_node, left_node | |||
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||
left_node, right_node = right_node, left_node | |||
if is_constant(left_node.expr): | |||
left_node, right_node = right_node, left_node | |||
cofee = -1 | |||
if left_node is None: | |||
return expr | |||
if isinstance(right_node, TensorNode): | |||
right_node = get_const_value(right_node.expr, right_node) | |||
graph = expr.top_graph | |||
with graph.insert_exprs(): | |||
if target in ["__mul__", "__imul__", "__rmul__", F.mul]: | |||
out_node = left_node * right_node | |||
elif target in ["__add__", "__iadd__", "__radd__", F.add]: | |||
out_node = left_node + right_node | |||
elif target in ["__sub__", "__isub__", "__rsub__", F.sub]: | |||
if cofee == -1: | |||
left_node = F.neg(left_node) | |||
else: | |||
if isinstance(right_node, TensorNode): | |||
right_node = F.neg(right_node) | |||
else: | |||
right_node = -1 * right_node | |||
out_node = left_node + right_node | |||
elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]: | |||
if cofee == -1: | |||
left_node = F.pow(left_node, -1) | |||
else: | |||
if isinstance(right_node, TensorNode): | |||
right_node = F.pow(right_node, -1) | |||
else: | |||
right_node = 1 / right_node | |||
out_node = left_node * right_node | |||
graph.replace_node({expr.outputs[0]: out_node}) | |||
graph.compile() | |||
return out_node.expr |
@@ -0,0 +1,298 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from collections import OrderedDict, defaultdict | |||
from copy import deepcopy | |||
from typing import Any, Dict, List, Set | |||
from ... import functional as F | |||
from ... import module as M | |||
from ...core.ops.builtin import GetVarShape | |||
from ...logger import get_logger | |||
from ...tensor import Parameter, Tensor | |||
from ..expr import ( | |||
Expr, | |||
is_apply_def, | |||
is_call_function, | |||
is_call_module, | |||
is_call_tensor_method, | |||
is_constant, | |||
is_getattr, | |||
) | |||
from ..traced_module import InternalGraph | |||
from ..utils import assign_attr, get_subattr | |||
from .matcher import PatternMatcher | |||
from .pass_base import BackwardPass, register_pass | |||
from .pattern import is_const, is_op, is_var | |||
from .utils import get_const_value | |||
logger = get_logger(__name__) | |||
@register_pass("BackwardFoldScale") | |||
class BackwardFoldScale(BackwardPass): | |||
r"""Backward fold const scaling into weights of conv2d. | |||
For example, the following code | |||
.. code-block:: | |||
x = conv(x, w, b) | |||
x = relu(x) | |||
x1 = x + 3 | |||
x2 = x + 4 | |||
y = (x1 + x2) * 3 | |||
will be changed to | |||
.. code-block:: | |||
x = conv(x, w * 3, b * 3) | |||
x = relu(x) | |||
x1 = x + 9 | |||
x2 = x + 12 | |||
y = x1 + x2 | |||
""" | |||
name = "BackwardFoldScale" | |||
required_pass = ["AttrToConstant", "NormElemWise"] | |||
run_once = True | |||
def __init__(self): | |||
super().__init__() | |||
        # todo: support more axes
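        # scale_message maps each visited Expr to the constant scale accumulated from its
        # downstream users, to be folded backward through (or into) the expr; a value of
        # None marks exprs through which scaling must not be propagated.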
self.scale_message = OrderedDict() | |||
self.used_names = defaultdict(int) | |||
def run_transform(self, expr: Expr) -> Expr: | |||
if expr not in self.scale_message: | |||
return expr | |||
var = is_var().check_users(False) | |||
mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||
add_const_pattern = var + is_const() | var + "*" | |||
conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d) | |||
pattern = conv_pattern | add_const_pattern | mul_const_pattern | |||
        matcher = PatternMatcher()
        if not matcher.match(pattern, expr):
            return expr
        matched_exprs = matcher.matched_exprs
        if conv_pattern in matched_exprs:
            return self.fold_conv_mul(expr)
        if mul_const_pattern in matched_exprs:
            return self.fold_mul(expr)
        if add_const_pattern in matched_exprs:
            return self.fold_add_mul(expr)
return expr | |||
def fold_add_mul(self, expr: Expr): | |||
if self.scale_message[expr] is None: | |||
return expr | |||
scale = self.scale_message[expr] | |||
if len(expr.inputs) == 1: | |||
const = expr.const_val[0][-1] | |||
else: | |||
const = get_const_value(expr.inputs[1]) | |||
const = const * scale | |||
inp_node = expr.inputs[0] | |||
graph = expr.top_graph | |||
with graph.insert_exprs(): | |||
add_node = inp_node + const | |||
graph.replace_node({expr.outputs[0]: add_node}) | |||
graph.compile() | |||
add_node.name = expr.outputs[0].name | |||
return add_node.expr | |||
def fold_mul(self, expr: Expr): | |||
if self.scale_message[expr] is None: | |||
return expr | |||
graph = expr.top_graph | |||
graph.replace_node({expr.outputs[0]: expr.inputs[0]}) | |||
graph.compile() | |||
return expr | |||
def fold_conv_mul(self, expr: Expr): | |||
graph = expr.top_graph | |||
scale = self.scale_message[expr] | |||
if scale is None: | |||
return expr | |||
if is_call_function(expr, F.conv2d): | |||
named_args = expr.named_args | |||
weight = get_const_value(named_args["weight"], named_args["weight"]) * scale | |||
bias = get_const_value(named_args["bias"], named_args["bias"]) * scale | |||
named_args["weight"] = weight | |||
named_args["bias"] = bias | |||
with graph.insert_exprs(): | |||
out_node = F.conv2d(**named_args) | |||
graph.replace_node({expr.outputs[0]: out_node}) | |||
graph.compile() | |||
out_node.name = expr.outputs[0].name | |||
return out_node.expr | |||
else: | |||
mnode = expr.inputs[0] | |||
attr_name = expr.inputs[0].expr.name | |||
graph = expr.top_graph | |||
if len(mnode.users) > 1: | |||
self.used_names[mnode.qualname] += 1 | |||
attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname]) | |||
logger.warning( | |||
"{} is used {} times and its name will be reset to {}.{}".format( | |||
mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||
) | |||
) | |||
conv_module = mnode.owner | |||
if len(mnode.users) > 1: | |||
conv_module = deepcopy(conv_module) | |||
conv_module._name = None | |||
conv_module.weight = Parameter(conv_module.weight * scale) | |||
if conv_module.bias is not None: | |||
conv_module.bias = Parameter(conv_module.bias * scale) | |||
if len(mnode.users) > 1: | |||
self_node = mnode.expr.inputs[0] | |||
assign_attr(conv_module, self_node.owner, attr_name) | |||
with graph.insert_exprs(mnode.expr): | |||
new_conv_node = get_subattr(self_node, attr_name) | |||
expr.replace_inputs({mnode: new_conv_node}) | |||
return expr | |||
def reset_expr_message_to_none( | |||
self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr], | |||
): | |||
if expr in skip_exprs: | |||
return | |||
scale_message[expr] = None | |||
if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d): | |||
return | |||
for out_node in expr.outputs: | |||
for user in out_node.users: | |||
if user in scale_message: | |||
self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||
def before_visit_graph(self, graph: InternalGraph): | |||
var = is_var().check_users(False) | |||
mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||
relu_pattern = ( | |||
is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU) | |||
) | |||
        # The parameters of conv must be constants; dynamic conv is not supported
conv_pattern = ( | |||
is_op(F.conv2d)(var, is_const(), is_const()) | |||
| is_op(F.conv2d)(var, is_const()) | |||
| is_op(M.Conv2d) | |||
) | |||
pattern = mul_const_pattern | relu_pattern | conv_pattern | |||
for op in [ | |||
"__add__", | |||
F.reshape, | |||
"reshape", | |||
F.transpose, | |||
"tranpose", | |||
F.min, | |||
"min", | |||
F.max, | |||
"max", | |||
F.max_pool2d, | |||
M.MaxPool2d, | |||
F.avg_pool2d, | |||
M.AvgPool2d, | |||
F.adaptive_avg_pool2d, | |||
M.AdaptiveAvgPool2d, | |||
F.adaptive_max_pool2d, | |||
M.AdaptiveMaxPool2d, | |||
F.expand_dims, | |||
F.concat, | |||
"__getitem__", | |||
]: | |||
pattern |= is_op(op) | |||
matcher = PatternMatcher() | |||
scale_message = OrderedDict() | |||
mem_conv_scale_message = OrderedDict() | |||
skip_exprs = self.init_skip_exprs(graph) | |||
for expr in reversed(graph._exprs): | |||
if expr in skip_exprs: | |||
continue | |||
if len(expr.outputs) > 1 or not matcher.match(pattern, expr): | |||
self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
if is_call_function(expr, F.conv2d): | |||
for user in expr.outputs[0].users: | |||
self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||
continue | |||
matched_exprs = matcher.matched_exprs | |||
const = None | |||
if mul_const_pattern in matched_exprs: | |||
if is_call_function(expr, F.neg): | |||
const = -1 | |||
elif len(expr.inputs) == 1: | |||
const = expr.const_val[0][-1] | |||
else: | |||
const = get_const_value(expr.inputs[1]) | |||
if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]: | |||
self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
continue | |||
users_const = [ | |||
scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs | |||
] | |||
if len(users_const) == 0: | |||
scale_message[expr] = const | |||
continue | |||
if any(c is None or c != users_const[0] for c in users_const): | |||
self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
scale_message[expr] = const | |||
continue | |||
const = 1 if const is None else const | |||
const = const * users_const[0] | |||
if relu_pattern in matched_exprs and const < 0: | |||
self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
continue | |||
if conv_pattern in matched_exprs: | |||
self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
mem_conv_scale_message[expr] = const | |||
continue | |||
scale_message[expr] = const | |||
self.scale_message.update(scale_message) | |||
self.scale_message.update(mem_conv_scale_message) | |||
def init_skip_exprs(self, graph: InternalGraph): | |||
skip_exprs = set() | |||
for expr in graph._exprs: | |||
if is_apply_def(expr, GetVarShape): | |||
skip_exprs.add(expr) | |||
elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs: | |||
skip_exprs.add(expr) | |||
elif is_getattr(expr): | |||
skip_exprs.add(expr) | |||
elif is_constant(expr): | |||
skip_exprs.add(expr) | |||
elif all(n.expr in skip_exprs for n in expr.inputs): | |||
skip_exprs.add(expr) | |||
return skip_exprs |
@@ -0,0 +1,248 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import operator | |||
from collections import defaultdict | |||
from typing import Any, Callable, List | |||
from ... import functional as F | |||
from ... import module as M | |||
from ...logger import get_logger | |||
from ...tensor import Parameter, Tensor | |||
from ...utils.bn_fusion import fold_weight_bias | |||
from ..expr import Expr, is_call_function | |||
from ..utils import assign_attr, get_subattr | |||
from .matcher import PatternMatcher | |||
from .pass_base import BackwardPass, register_pass | |||
from .pattern import ExprPattern, any_node, is_const, is_op, is_var | |||
from .utils import get_const_value, register_obj | |||
logger = get_logger(__name__) | |||
@register_pass("FuseAddMul") | |||
class FuseAddMul(BackwardPass): | |||
"""Fold adjacent const add or mul binary operations. | |||
For example, the following code | |||
.. code-block:: | |||
x = x + 1 | |||
x = 2 + x | |||
x = x * 4 | |||
x = x * 0.25 | |||
will be changed to | |||
.. code-block:: | |||
x = x + 3 | |||
""" | |||
name = "FuseAddMul" | |||
required_pass = ["NormElemWise"] | |||
run_once = False | |||
def __init__(self,): | |||
super().__init__() | |||
def _make_pattern(op_0, op_1) -> ExprPattern: | |||
x = is_var().check_users(False) | |||
if op_0 not in [operator.add, operator.mul]: | |||
op_0 = is_op(op_0) | |||
if op_1 not in [operator.add, operator.mul]: | |||
op_1 = is_op(op_1) | |||
pattern = op_0(x, is_const()) | op_0(x, "*") | |||
pattern = op_1(pattern, is_const()) | op_1(pattern, "*") | |||
return pattern | |||
self.pattern_dict = {} | |||
for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],): | |||
self.pattern_dict[_make_pattern(op, op)] = func | |||
for op_0 in [F.neg, operator.mul]: | |||
for op_1 in [F.neg, operator.mul]: | |||
self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul | |||
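        # F.neg is treated as multiplication by -1 (see get_const_value below), so chains of
        # neg and constant mul collapse into a single mul.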
def run_transform(self, expr: Expr): | |||
matcher = PatternMatcher() | |||
for pattern, func in self.pattern_dict.items(): | |||
res = matcher.match(pattern, expr) | |||
if res: | |||
break | |||
if not res: | |||
return expr | |||
return func(expr) | |||
def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable): | |||
const_0 = self.get_const_value(expr) | |||
        # todo: support more shapes
if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]: | |||
return expr | |||
const_1 = self.get_const_value(expr.inputs[0].expr) | |||
if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]: | |||
return expr | |||
inp_node = expr.inputs[0].expr.inputs[0] | |||
const = op_c(const_0, const_1) | |||
graph = expr.top_graph | |||
if (const == 1 and op_t in [operator.pow, operator.mul]) or ( | |||
const == 0 and op_t in [operator.add] | |||
): | |||
graph.replace_node({expr.outputs[0]: inp_node}) | |||
graph.compile() | |||
return expr | |||
with expr.top_graph.insert_exprs(): | |||
out_node = op_t(inp_node, const) | |||
graph.replace_node({expr.outputs[0]: out_node}) | |||
graph.compile() | |||
return out_node.expr | |||
def fold_add(self, expr: Expr): | |||
return self._fold_helper(expr, operator.add, operator.add) | |||
def fold_mul(self, expr): | |||
return self._fold_helper(expr, operator.mul, operator.mul) | |||
def fold_pow(self, expr): | |||
return self._fold_helper(expr, operator.mul, F.pow) | |||
def get_const_value(self, expr: Expr): | |||
if is_call_function(expr, F.neg): | |||
return -1 | |||
if len(expr.inputs) == 2: | |||
value = get_const_value(expr.inputs[1].expr, None) | |||
            assert value is not None, "expected the second input of {} to be a constant".format(expr)
return value | |||
value = expr.const_val[0][-1] | |||
return value | |||
@register_pass("FuseConvBn") | |||
class FuseConvBn(BackwardPass): | |||
r"""Fuse BN layers into conv2d.""" | |||
name = "FuseConvBn" | |||
required_pass = ["AttrToConstant"] | |||
run_once = True | |||
def __init__(self): | |||
super().__init__() | |||
self.used_name = defaultdict(int) | |||
def run_transform(self, expr: Expr): | |||
conv_pat_0 = is_op(M.Conv2d) | |||
conv_pat_1 = is_op(F.conv2d) | |||
bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1) | |||
bn_pat_1 = is_op(F.batch_norm) | |||
# inp, running_mean, running_var, weight, bias | |||
bn_inps = ( | |||
conv_pat_0 | conv_pat_1, | |||
is_const(), | |||
is_const(), | |||
is_const(), | |||
is_const(), | |||
) | |||
bn_pat = ( | |||
(bn_pat_1(*bn_inps[:3])) | |||
| (bn_pat_1(*bn_inps[:4])) | |||
| (bn_pat_1(*bn_inps)) | |||
| bn_pat_0 | |||
) | |||
matcher = PatternMatcher() | |||
if not matcher.match(bn_pat, expr): | |||
return expr | |||
matched_exprs = matcher.matched_exprs | |||
if conv_pat_0 in matched_exprs: | |||
return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat]) | |||
else: | |||
return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat]) | |||
def fold_convm_bn(self, conv: Expr, bn: Expr): | |||
mnode, inp_node = conv.inputs[:2] | |||
self_node = mnode.expr.inputs[0] | |||
attr_name = conv.inputs[0].expr.name | |||
graph = conv.top_graph | |||
if len(mnode.users) > 1: | |||
self.used_name[mnode.qualname] += 1 | |||
attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname]) | |||
logger.warning( | |||
"{} is used {} times and its name will be reset to {}.{}".format( | |||
mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||
) | |||
) | |||
conv_module = mnode.owner | |||
weight, bias = conv_module.weight, conv_module.bias | |||
mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||
weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||
new_conv = M.Conv2d( | |||
in_channels=conv_module.in_channels, | |||
out_channels=conv_module.out_channels, | |||
kernel_size=conv_module.kernel_size, | |||
stride=conv_module.stride, | |||
padding=conv_module.padding, | |||
dilation=conv_module.dilation, | |||
groups=conv_module.groups, | |||
bias=conv_module.bias is not None, | |||
conv_mode=conv_module.conv_mode, | |||
compute_mode=conv_module.compute_mode, | |||
name=conv_module.name, | |||
) | |||
new_conv.weight = Parameter(weight) | |||
new_conv.bias = Parameter(bias) | |||
new_conv.training = conv_module.training | |||
assign_attr(new_conv, self_node.owner, attr_name) | |||
with graph.insert_exprs(mnode.expr): | |||
out_node = get_subattr(self_node, attr_name)(inp_node) | |||
graph.replace_node({bn.outputs[0]: out_node}) | |||
graph.compile() | |||
out_node.name = conv.outputs[0].name | |||
return out_node.expr | |||
def fold_convf_bn(self, conv: Expr, bn: Expr): | |||
named_args = conv.named_args | |||
weight = get_const_value(named_args["weight"], named_args["weight"]) | |||
bias = get_const_value(named_args["bias"], named_args["bias"]) | |||
mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||
weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||
named_args["weight"] = weight | |||
named_args["bias"] = bias | |||
graph = conv.top_graph | |||
with graph.insert_exprs(): | |||
out_node = F.conv2d(**named_args) | |||
graph.replace_node({bn.outputs[0]: out_node}) | |||
graph.compile() | |||
out_node.name = conv.outputs[0].name | |||
return out_node.expr | |||
def get_bn_params(self, bn: Expr): | |||
if is_call_function(bn): | |||
named_args = bn.named_args | |||
mean = get_const_value( | |||
named_args["running_mean"], named_args["running_mean"] | |||
) | |||
var = get_const_value(named_args["running_var"], named_args["running_var"]) | |||
gamma = get_const_value(named_args["weight"], named_args["weight"]) | |||
beta = get_const_value(named_args["bias"], named_args["bias"]) | |||
eps = named_args["eps"] | |||
return mean, var, gamma, beta, eps | |||
else: | |||
bn_module = bn.inputs[0].owner | |||
mean = bn_module.running_mean | |||
var = bn_module.running_var | |||
gamma = bn_module.weight | |||
beta = bn_module.bias | |||
eps = bn_module.eps | |||
return mean, var, gamma, beta, eps |
@@ -0,0 +1,190 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import copy | |||
from abc import abstractmethod | |||
from collections import OrderedDict, namedtuple | |||
from functools import partial | |||
from typing import Any, Callable, Dict, Iterable, List, Union | |||
from ...logger import get_logger | |||
from ..expr import Expr | |||
from ..traced_module import InternalGraph, TracedModule | |||
from .utils import register_obj | |||
logger = get_logger(__name__) | |||
class PassContext: | |||
def __init__( | |||
self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None | |||
): | |||
self._disabled_pass = set() | |||
self._config = pass_config | |||
self._handle = None | |||
if disabled_pass: | |||
            self.add_disabled_pass(disabled_pass)
    def add_disabled_pass(self, passes: Iterable[str]):
if isinstance(passes, str): | |||
passes = [passes] | |||
for pas in passes: | |||
self._disabled_pass.add(pas) | |||
def pass_enabled(self, pas: Union["BasePass", str]): | |||
pass_name = pas.name if isinstance(pas, BasePass) else pas | |||
return pass_name not in self._disabled_pass | |||
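# Illustrative use of PassContext (a sketch, not part of the original source):
#     ctx = PassContext(disabled_pass=["FuseAddMul"])
# disables that pass, and any pass that lists it in `required_pass` will return the
# module unchanged.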
_default_context = PassContext() | |||
def get_default_pass_context(): | |||
return _default_context | |||
_pass_dict = OrderedDict() | |||
register_pass = partial(register_obj, _dict=_pass_dict) | |||
def get_registered_pass(pass_name: str): | |||
pas = _pass_dict.get(pass_name, None) | |||
assert ( | |||
pas is not None | |||
), "{} is not found, please call `register_pass` to register it".format(pass_name) | |||
return pas | |||
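# Illustrative usage (a sketch, not part of the original source; assumes `tm` is a
# TracedModule produced by trace_module):
#
#     pass_cls = get_registered_pass("FuseConvBn")
#     tm = pass_cls()(tm, get_default_pass_context())
#
# Each pass first applies the passes in its `required_pass` list, then visits the graph
# repeatedly until no expr changes (or only once if `run_once` is set).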
class BasePass: | |||
run_once = True # bool | |||
required_pass = [] # Iterable[str] | |||
name = "" # str | |||
def __init__(self): | |||
super().__init__() | |||
def __call__( | |||
self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context() | |||
) -> TracedModule: | |||
assert isinstance(pass_ctx, PassContext) | |||
return self.apply_optimization(mod, pass_ctx) | |||
def apply_optimization( | |||
self, mod: TracedModule, pass_ctx: PassContext | |||
) -> TracedModule: | |||
new_mod = mod | |||
for pass_name in self.required_pass + [self.name]: | |||
if not pass_ctx.pass_enabled(pass_name): | |||
logger.warning( | |||
"Since {} is disabled, {} will skipped".format(pass_name, self.name) | |||
) | |||
return mod | |||
for pass_name in self.required_pass: | |||
pass_func = get_registered_pass(pass_name)() | |||
new_mod = pass_func(new_mod, pass_ctx) | |||
iter_num = 1 | |||
graph_changed = self.visit_graph(new_mod.graph) | |||
while not self.run_once and graph_changed: | |||
graph_changed = self.visit_graph(new_mod.graph) | |||
iter_num += 1 | |||
if iter_num == 100: | |||
break | |||
        assert iter_num < 100, "{} was run 100 times, please check for pass conflicts.".format(self.name)
return new_mod | |||
@abstractmethod | |||
def visit_graph(self, graph: InternalGraph): | |||
raise NotImplementedError | |||
def before_visit_graph(self, graph: InternalGraph): | |||
pass | |||
def run_transform(self, expr: Expr) -> Expr: | |||
return expr | |||
def __repr__(self) -> str: | |||
return self.name | |||
class ForwardPass(BasePass): | |||
def visit_graph(self, graph: InternalGraph): | |||
class Item: | |||
def __init__(self, expr: Expr, child_expanded: bool = False): | |||
self.expr = expr | |||
self.child_expanded = child_expanded | |||
self.before_visit_graph(graph) | |||
graph_changed = False | |||
queue = [Item(n.expr) for n in graph.outputs] | |||
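        # Depth-first walk starting from the graph outputs: an expr is transformed only after
        # every expr producing its inputs has been visited, i.e. in forward (topological) order.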
visited_expr, visited_graph = set(), set() | |||
while queue: | |||
item = queue[-1] | |||
if item.expr in visited_expr: | |||
queue.pop() | |||
elif item.child_expanded: | |||
if item.expr not in graph._exprs: | |||
queue.pop() | |||
continue | |||
new_expr = self.run_transform(item.expr) | |||
if new_expr is not item.expr: | |||
graph_changed = True | |||
assert new_expr not in visited_expr | |||
queue.append(Item(new_expr)) | |||
continue | |||
if ( | |||
hasattr(item.expr, "graph") | |||
and item.expr.graph is not None | |||
and item.expr.graph not in visited_graph | |||
): | |||
graph_changed |= self.visit_graph(item.expr.graph) | |||
visited_graph.add(item.expr.graph) | |||
visited_expr.add(item.expr) | |||
else: | |||
item.child_expanded = True | |||
for i in item.expr.inputs: | |||
expr = i.expr | |||
if expr not in queue and expr not in visited_expr: | |||
queue.append(Item(expr)) | |||
return graph_changed | |||
class BackwardPass(BasePass): | |||
def visit_graph(self, graph: InternalGraph): | |||
self.before_visit_graph(graph) | |||
graph_changed = False | |||
queue = [n.expr for n in graph.outputs] | |||
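        # Walk from the graph outputs toward the inputs, so an expr is transformed before the
        # exprs that produce its inputs, i.e. in backward (reverse topological) order.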
visited_expr, visited_graph = set(), set() | |||
while queue: | |||
expr = queue.pop() | |||
if expr not in graph._exprs: | |||
continue | |||
new_expr = self.run_transform(expr) | |||
if new_expr is not expr: | |||
graph_changed = True | |||
queue.append(new_expr) | |||
continue | |||
else: | |||
visited_expr.add(expr) | |||
if ( | |||
hasattr(expr, "graph") | |||
and expr.graph is not None | |||
and expr.graph not in visited_graph | |||
): | |||
graph_changed |= self.visit_graph(expr.graph) | |||
visited_graph.add(expr.graph) | |||
for i in expr.inputs: | |||
expr = i.expr | |||
if expr not in queue and expr not in visited_expr: | |||
queue.append(expr) | |||
return graph_changed |
@@ -13,7 +13,7 @@ import inspect | |||
import re | |||
import weakref | |||
from importlib import import_module | |||
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||
from ..core._imperative_rt import OpDef | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
@@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str): | |||
return matchd.group(1) | |||
def is_call_module(expr): | |||
def is_call_module(expr, module_cls: Module = None): | |||
return ( | |||
isinstance(expr, CallMethod) | |||
and isinstance(expr.inputs[0], ModuleNode) | |||
and expr.method == "__call__" | |||
) | |||
) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls)) | |||
def is_call_tensor_method(expr): | |||
return isinstance(expr, CallMethod) and not is_call_module(expr) | |||
def is_call_tensor_method(expr, method: Iterable[str] = None): | |||
if method and isinstance(method, str): | |||
method = (method,) | |||
return ( | |||
isinstance(expr, CallMethod) | |||
and not is_call_module(expr) | |||
and (method is None or any(expr.method == f for f in method)) | |||
) | |||
def is_call_function(expr): | |||
return isinstance(expr, CallFunction) | |||
def is_call_function(expr, func: Iterable[Callable] = None): | |||
if func and not isinstance(func, Iterable): | |||
func = (func,) | |||
return isinstance(expr, CallFunction) and ( | |||
func is None or any(expr.func == f for f in func) | |||
) | |||
def is_constant(expr): | |||
@@ -74,8 +84,8 @@ def is_getattr(expr): | |||
return isinstance(expr, GetAttr) | |||
def is_apply_def(expr): | |||
return isinstance(expr, Apply) | |||
def is_apply_def(expr, opdef=None): | |||
return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef)) | |||
def is_input(expr): | |||
@@ -78,6 +78,7 @@ class Node: | |||
"The name(%s) is already in use. Please try a different one again." | |||
% (new_name) | |||
) | |||
graph._namespace.unassociate_name_with_obj(self) | |||
self._name = graph._namespace.create_unique_name(new_name, self) | |||
@property | |||
@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni | |||
from .. import get_logger | |||
from ..module import Module | |||
from ..tensor import Parameter, Tensor | |||
logger = get_logger(__name__) | |||
@@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping): | |||
def forward(self): | |||
raise RuntimeError("ModuleList is not callable") | |||
def assign_attr(obj: Union[Module, Tensor], module: Module, target: str): | |||
*prefix, name = target.split(".") | |||
for item in prefix: | |||
module = getattr(module, item) | |||
if not isinstance(module, Module): | |||
            raise AttributeError("`{}` is not a Module".format(item))
setattr(module, name, obj) | |||
def get_subattr(module: Module, target: str): | |||
# todo : remove this import | |||
from .node import ModuleNode | |||
if target == "": | |||
return module | |||
*prefix, name = target.split(".") | |||
for item in prefix: | |||
module = getattr(module, item) | |||
if not isinstance(module, (Module, ModuleNode)): | |||
raise AttributeError("`{}` is not an Module".format(item)) | |||
return getattr(module, name) |
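# Illustrative example (a sketch, not part of the original source): for a module `m` whose
# child is reachable as `m.block.conv`,
#     get_subattr(m, "block.conv")            # returns m.block.conv
#     assign_attr(new_conv, m, "block.conv")  # rebinds m.block.conv to `new_conv`
# where `new_conv` is any Module or Tensor to install at that dotted path.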
@@ -0,0 +1,86 @@ | |||
from copy import deepcopy | |||
from ..functional import ones, sqrt, zeros | |||
from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||
from ..tensor import Parameter | |||
_MAP_TO_FUSED_MODULE = { | |||
(Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, | |||
(Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, | |||
(Conv2d, BatchNorm2d, False): Conv2d, | |||
(Conv2d, BatchNorm2d, True): ConvBn2d, | |||
(Conv2d, ReLU): ConvRelu2d, | |||
} | |||
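# Key layout: (conv type, [bn type], [relu type], [conv.training]). For example, a Conv2d
# followed by BatchNorm2d and ReLU maps to ConvRelu2d in eval mode (the BN statistics are
# folded into the conv weights below) and to ConvBnRelu2d in train mode, where the BN
# module is kept.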
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): | |||
# get fold bn conv param | |||
kernel_shape = weight.shape | |||
if len(kernel_shape) == 5: | |||
groups, num_features = kernel_shape[0], kernel_shape[1] | |||
else: | |||
groups, num_features = 1, kernel_shape[0] | |||
if gamma is None: | |||
gamma = ones((num_features), dtype="float32") | |||
gamma = gamma.reshape(1, -1, 1, 1) | |||
if beta is None: | |||
beta = zeros((num_features), dtype="float32") | |||
beta = beta.reshape(1, -1, 1, 1) | |||
if bn_mean is None: | |||
bn_mean = zeros((1, num_features, 1, 1), dtype="float32") | |||
if bn_var is None: | |||
bn_var = ones((1, num_features, 1, 1), dtype="float32") | |||
if bias is None: | |||
bias = zeros((1, num_features, 1, 1), dtype="float32") | |||
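    # Folding identity: bn(conv(x)) = gamma * (W @ x + b - mean) / sqrt(var + eps) + beta,
    # so W_fold = W * gamma / sqrt(var + eps) and b_fold = beta + gamma * (b - mean) / sqrt(var + eps).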
bn_istd = 1.0 / sqrt(bn_var + eps) | |||
scale_factor = gamma * bn_istd | |||
if groups == 1: | |||
w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) | |||
else: | |||
w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) | |||
b_fold = beta + gamma * (bias - bn_mean) * bn_istd | |||
return w_fold, b_fold | |||
def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||
module_key = tuple([type(m) for m in [conv, bn, relu] if m]) | |||
if bn: | |||
assert ( | |||
conv.training == bn.training | |||
), "Conv and BN both must be in the same mode (train or eval)." | |||
assert ( | |||
bn.num_features == conv.out_channels | |||
), "Output channel of Conv2d must match num_features of BatchNorm2d" | |||
module_key = module_key + (conv.training,) | |||
module = _MAP_TO_FUSED_MODULE[module_key]( | |||
in_channels=conv.in_channels, | |||
out_channels=conv.out_channels, | |||
kernel_size=conv.kernel_size, | |||
stride=conv.stride, | |||
padding=conv.padding, | |||
dilation=conv.dilation, | |||
groups=conv.groups, | |||
bias=conv.bias is not None, | |||
conv_mode=conv.conv_mode, | |||
compute_mode=conv.compute_mode, | |||
name=conv.name, | |||
) | |||
new_conv = module if bn is None or not conv.training else module.conv | |||
weight, bias = conv.weight, conv.bias | |||
if not conv.training and bn is not None: | |||
weight, bias = fold_weight_bias( | |||
weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, | |||
) | |||
new_conv.weight = Parameter(weight) | |||
if bias is not None: | |||
new_conv.bias = Parameter(bias) | |||
if bn is not None and conv.training: | |||
module.bn = deepcopy(bn) | |||
new_conv.training = conv.training | |||
return module |
@@ -13,20 +13,7 @@ import megengine.quantization as Q | |||
from megengine import Tensor | |||
from megengine.module.qat.module import QATModule | |||
from megengine.traced_module import TracedModule, trace_module | |||
def get_subattr(self: M.Module, name: str): | |||
if name == "": | |||
return self | |||
module_path, _, name = name.rpartition(".") | |||
if module_path == "": | |||
return getattr(self, name) | |||
module_names = module_path.split(".") | |||
for item in module_names: | |||
self = getattr(self, item) | |||
if not isinstance(self, M.Module): | |||
raise AttributeError("`{}` is not an Module".format(item)) | |||
return getattr(self, name) | |||
from megengine.traced_module.utils import get_subattr | |||
class MyConvBnRelu2d(M.ConvBnRelu2d): | |||