@@ -8,6 +8,7 @@ | |||||
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | ||||
from . import compat | from . import compat | ||||
from ._passes import optimize | |||||
from .traced_module import ( | from .traced_module import ( | ||||
TracedModule, | TracedModule, | ||||
_register_all_builtin_module, | _register_all_builtin_module, | ||||
@@ -19,3 +20,11 @@ from .traced_module import ( | |||||
_register_all_builtin_module() | _register_all_builtin_module() | ||||
set_cpp_apply_module_trace(cpp_apply_module_trace) | set_cpp_apply_module_trace(cpp_apply_module_trace) | ||||
__all__ = { | |||||
"register_as_builtin", | |||||
"trace_module", | |||||
"wrap", | |||||
"TracedModule", | |||||
"optimize", | |||||
} |
@@ -0,0 +1,12 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import const_pass, fold_scale_pass, fuse_pass | |||||
from .optimization import optimize | |||||
__all__ = ["optimize"] |
@@ -0,0 +1,70 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy | |||||
from typing import List, Set | |||||
from ...logger import get_logger | |||||
from ..traced_module import TracedModule | |||||
from .pass_base import get_default_pass_context, get_registered_pass | |||||
logger = get_logger(__name__) | |||||
def optimize(
    module: TracedModule, enabled_pass: List[str] = ("FuseConvBn",),
) -> TracedModule:
    r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.

    The following passes are currently supported:

        * FuseConvBn: fuse BN layers into to conv2d
        * FuseAddMul: fold adjacent const add or mul binary operations
        * BackwardFoldScale: backward fold const scaling into weights of conv2d

    Args:
        module: the :class:`TracedModule` to be optimized.
        enabled_pass: optimization passes to be enabled during optimization.
            A single pass name or a sequence of names. Default: ["FuseConvBn"]

    Returns:
        the optimized :class:`TracedModule`.
    """
    # Canonical execution order of the passes; only those in ``enabled_pass``
    # are actually run.
    default_passes_list = [
        "FuseConvBn",
        "FuseAddMul",
    ]
    if isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    # Copy before any mutation: the original code appended to ``enabled_pass``
    # in place, which silently modified the caller's list (and, with the old
    # mutable ``["FuseConvBn"]`` default, leaked state across calls).
    enabled_pass = list(enabled_pass)
    if "BackwardFoldScale" in enabled_pass:
        if "FuseConvBn" not in enabled_pass:
            logger.warning(
                "Since BackwardFoldScale requires FuseConvBn"
                ", FuseConvBn will be enabled."
            )
            enabled_pass.append("FuseConvBn")
        # FuseAddMul appears a second time on purpose: folding scales backward
        # can expose new constant add/mul chains worth folding again.
        default_passes_list.extend(
            ["BackwardFoldScale", "FuseAddMul",]
        )
    pass_ctx = get_default_pass_context()

    def run_pass(mod: TracedModule):
        # Apply every enabled pass in canonical order, threading the module
        # through each pass function.
        for pass_name in default_passes_list:
            if pass_name in enabled_pass:
                pass_func = get_registered_pass(pass_name)()
                mod = pass_func(mod, pass_ctx)
        return mod

    # Optimize a deep copy so the caller's module is left untouched.
    module = deepcopy(module)
    module = run_pass(module)
    return module
@@ -0,0 +1,106 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import types | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.functional as F | |||||
import megengine.module as M | |||||
import megengine.traced_module as tm | |||||
class myconv(M.Conv2d):
    # User-defined Conv2d subclass: the tests parametrize over it to check
    # that the optimization passes also match subclasses, not just M.Conv2d.
    pass
class mybn(M.BatchNorm2d):
    # User-defined BatchNorm2d subclass: the tests parametrize over it to
    # check that the passes also match subclasses, not just M.BatchNorm2d.
    pass
class MyBlock(M.Module):
    """Two parallel conv->bn->relu branches, each scaled by a constant,
    merged with a chain of const adds and a final scaled relu — exercises
    FuseConvBn, FuseAddMul and BackwardFoldScale."""

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        # Branch one: conv -> bn -> relu, scaled by scale[0].
        branch_a = F.relu(self.bn(self.conv(x))) * self.scale[0]
        # Branch two: conv2 -> bn2 -> relu, scaled by scale[1].
        branch_b = F.relu(self.bn2(self.conv2(x))) * self.scale[1]
        # Merge and apply const adds / final scaled relu.
        out = branch_a + branch_b
        out = out + 4
        out = self.scale[0] + out
        out = F.relu(out) * 3
        return out
class MyModule(M.Module):
    """Sums the outputs of two MyBlock instances, flattens, and scales."""

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        # block_0 is evaluated first, matching the original statement order.
        merged = self.block_0(x) + self.block_1(x)
        flat = F.reshape(merged, (-1))
        return flat * 3
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
    """BackwardFoldScale keeps outputs (within atol) and removes every
    const multiplication from the traced graph."""
    net = MyModule(conv_cls, bn_cls)
    net.eval()
    data = mge.Tensor(np.random.random((1, 3, 32, 32)))
    expected = net(data)
    optimized = tm.optimize(
        tm.trace_module(net, data).flatten(), "BackwardFoldScale"
    )
    got = optimized(data)
    np.testing.assert_allclose(desired=expected, actual=got, atol=1e-4)
    # every const mul should have been folded into conv weights
    remaining_muls = optimized.graph.get_method_by_type("__mul__").as_list()
    assert len(remaining_muls) == 0
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
    # FuseConvBn must preserve the network's output (within atol) while
    # removing every batch-norm node from the traced graph.
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)
    traced_net = tm.trace_module(module, inp)
    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "FuseConvBn")
    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # all batch norms (both functional and module form) fused into conv
    bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
    assert len(bn_list) == 0
    bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
    assert len(bn_list) == 0