@@ -8,6 +8,7 @@
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .traced_module import (
    TracedModule,
    _register_all_builtin_module,
@@ -19,3 +20,11 @@ from .traced_module import (
_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)

__all__ = [
    "register_as_builtin",
    "trace_module",
    "wrap",
    "TracedModule",
    "optimize",
]
@@ -0,0 +1,12 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import const_pass, fold_scale_pass, fuse_pass
from .optimization import optimize

__all__ = ["optimize"]
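The three pass modules above appear to be imported for their side effects: each presumably registers its passes under a string name so that `get_registered_pass` in `optimization.py` (next hunk) can look them up. Below is a minimal sketch of such a name-keyed registry; only `get_registered_pass` and `get_default_pass_context` are confirmed by this patch, while `register_pass` and the backing dict are hypothetical illustrations.

```python
# Hypothetical sketch of the registry presumed to live in pass_base.
# register_pass and _pass_dict are assumptions for illustration; only
# get_registered_pass is confirmed by this patch.
_pass_dict = {}

def register_pass(name: str):
    def insert(pass_cls):
        _pass_dict[name] = pass_cls
        return pass_cls
    return insert

def get_registered_pass(name: str):
    return _pass_dict[name]

# a pass module would then self-register at import time, e.g.:
# @register_pass("FuseConvBn")
# class FuseConvBn(BasePass): ...
```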
@@ -0,0 +1,70 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import List, Union

from ...logger import get_logger
from ..traced_module import TracedModule
from .pass_base import get_default_pass_context, get_registered_pass

logger = get_logger(__name__)
def optimize(
    module: TracedModule, enabled_pass: Union[str, List[str]] = ["FuseConvBn"],
) -> TracedModule:
    r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.

    The following passes are currently supported:

    * FuseConvBn: fuse BatchNorm layers into the preceding conv2d
    * FuseAddMul: fold adjacent constant add or mul binary operations
    * BackwardFoldScale: fold constant scaling backward into the weights of conv2d

    Args:
        module: the :class:`TracedModule` to be optimized.
        enabled_pass: name(s) of the optimization passes to be enabled.
            Default: ["FuseConvBn"]

    Returns:
        the optimized :class:`TracedModule`.
    """
    default_passes_list = [
        "FuseConvBn",
        "FuseAddMul",
    ]

    if isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    # work on a copy so the caller's list (and the mutable default) is never mutated
    enabled_pass = list(enabled_pass)

    if "BackwardFoldScale" in enabled_pass:
        if "FuseConvBn" not in enabled_pass:
            logger.warning(
                "Since BackwardFoldScale requires FuseConvBn"
                ", FuseConvBn will be enabled."
            )
            enabled_pass.append("FuseConvBn")
        # rerun FuseAddMul after scale folding to clean up the remaining const ops
        default_passes_list.extend(["BackwardFoldScale", "FuseAddMul"])
    pass_ctx = get_default_pass_context()

    def run_pass(mod: TracedModule):
        for pass_name in default_passes_list:
            if pass_name in enabled_pass:
                pass_func = get_registered_pass(pass_name)()
                mod = pass_func(mod, pass_ctx)
        return mod

    module = deepcopy(module)
    module = run_pass(module)
    return module
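For reference, a minimal end-to-end sketch of the entry point above, mirroring the tests in the next hunk; the toy `Sequential` net here is illustrative and not part of this patch:

```python
import numpy as np
import megengine as mge
import megengine.module as M
import megengine.traced_module as tm

net = M.Sequential(M.Conv2d(3, 8, 3, padding=1), M.BatchNorm2d(8))
net.eval()  # BN statistics must be frozen for FuseConvBn to be valid

inp = mge.Tensor(np.random.random((1, 3, 32, 32)).astype("float32"))
traced = tm.trace_module(net, inp)

# optimize deepcopies the traced module, so `traced` itself is left untouched
fused = tm.optimize(traced, enabled_pass=["FuseConvBn"])
np.testing.assert_allclose(net(inp), fused(inp), atol=1e-4)
```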
@@ -0,0 +1,106 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class myconv(M.Conv2d):
    pass


class mybn(M.BatchNorm2d):
    pass


class MyBlock(M.Module):
    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        x1 = self.conv(x)
        x1 = self.bn(x1)
        x1 = F.relu(x1)
        x1 = x1 * self.scale[0]
        x2 = self.conv2(x)
        x2 = self.bn2(x2)
        x2 = F.relu(x2)
        x2 = x2 * self.scale[1]
        y = x1 + x2
        y = y + 4
        y = self.scale[0] + y
        y = F.relu(y) * 3
        return y
class MyModule(M.Module):
    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        x1 = self.block_0(x)
        x2 = self.block_1(x)
        y = x1 + x2
        y = F.reshape(y, (-1))
        y = y * 3
        return y
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)
    traced_net = tm.trace_module(module, inp)
    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "BackwardFoldScale")
    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # all const muls should have been folded into the conv weights
    mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list()
    assert len(mul_list) == 0
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)
    traced_net = tm.trace_module(module, inp)
    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "FuseConvBn")
    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # all bns (functional and module forms) should have been fused into the convs
    bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
    assert len(bn_list) == 0
    bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
    assert len(bn_list) == 0
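A hypothetical companion test for the remaining documented pass, FuseAddMul, written in the same pattern as the two tests above; it asserts only numerical equivalence, since the exact number of add/mul exprs that survive folding depends on the graph structure:

```python
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_add_mul(conv_cls, bn_cls):
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)
    traced_net = tm.trace_module(module, inp).flatten()
    optimized_net = tm.optimize(traced_net, "FuseAddMul")
    # numerics only: const folding must not change the result
    np.testing.assert_allclose(desired=desired, actual=optimized_net(inp), atol=1e-4)
```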