Browse Source

feat(mge/traced_module): add optimization api

GitOrigin-RevId: eaa7402640
release-1.7
Megvii Engine Team 3 years ago
parent
commit
7a023c059a
4 changed files with 197 additions and 0 deletions
  1. +9
    -0
      imperative/python/megengine/traced_module/__init__.py
  2. +12
    -0
      imperative/python/megengine/traced_module/_passes/__init__.py
  3. +70
    -0
      imperative/python/megengine/traced_module/_passes/optimization.py
  4. +106
    -0
      imperative/python/test/unit/traced_module/test_passes.py

+ 9
- 0
imperative/python/megengine/traced_module/__init__.py View File

@@ -8,6 +8,7 @@

from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .traced_module import (
TracedModule,
_register_all_builtin_module,
@@ -19,3 +20,11 @@ from .traced_module import (

_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)

# Public API of megengine.traced_module.
# NOTE: use a list, not a set — __all__ is conventionally a sequence of
# strings so the export order is deterministic and tools (linters, doc
# generators, IDEs) that read __all__ handle it correctly.
__all__ = [
    "register_as_builtin",
    "trace_module",
    "wrap",
    "TracedModule",
    "optimize",
]

+ 12
- 0
imperative/python/megengine/traced_module/_passes/__init__.py View File

@@ -0,0 +1,12 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from . import const_pass, fold_scale_pass, fuse_pass
from .optimization import optimize

__all__ = ["optimize"]

+ 70
- 0
imperative/python/megengine/traced_module/_passes/optimization.py View File

@@ -0,0 +1,70 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from copy import deepcopy
from typing import List, Set

from ...logger import get_logger
from ..traced_module import TracedModule
from .pass_base import get_default_pass_context, get_registered_pass

logger = get_logger(__name__)


def optimize(
    module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"],
) -> TracedModule:
    r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.

    The following passes are currently supported:

    * FuseConvBn: fuse BN layers into conv2d
    * FuseAddMul: fold adjacent const add or mul binary operations
    * BackwardFoldScale: backward fold const scaling into weights of conv2d

    Args:
        module: the :class:`TracedModule` to be optimized.
        enabled_pass: optimization passes to be enabled during optimization.
            Default: ["FuseConvBn"]

    Returns:
        the optimized :class:`TracedModule`.
    """
    # Canonical execution order of the passes; a pass only runs when it is
    # also present in ``enabled_pass``.
    default_passes_list = [
        "FuseConvBn",
        "FuseAddMul",
    ]

    if isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    else:
        # Copy so we never mutate the caller's list — or the shared mutable
        # default argument — when enabling dependent passes below.
        enabled_pass = list(enabled_pass)

    if "BackwardFoldScale" in enabled_pass:
        if "FuseConvBn" not in enabled_pass:
            logger.warning(
                "Since BackwardFoldScale requires FuseConvBn"
                ", FuseConvBn will be enabled."
            )
            enabled_pass.append("FuseConvBn")
        # FuseAddMul is scheduled again after BackwardFoldScale on purpose:
        # the backward fold can expose new adjacent const add/mul pairs.
        default_passes_list.extend(
            ["BackwardFoldScale", "FuseAddMul",]
        )

    pass_ctx = get_default_pass_context()

    def _run_passes(mod: TracedModule) -> TracedModule:
        # Run the enabled passes in canonical order, threading the module
        # through each registered pass functor.
        for pass_name in default_passes_list:
            if pass_name in enabled_pass:
                pass_func = get_registered_pass(pass_name)()
                mod = pass_func(mod, pass_ctx)
        return mod

    # Work on a deep copy so the input module is left untouched.
    module = deepcopy(module)
    module = _run_passes(module)

    return module

+ 106
- 0
imperative/python/test/unit/traced_module/test_passes.py View File

@@ -0,0 +1,106 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import types

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm


class myconv(M.Conv2d):
    # Trivial Conv2d subclass: verifies the passes match user-defined conv
    # modules, not only the builtin M.Conv2d.
    pass


class mybn(M.BatchNorm2d):
    # Trivial BatchNorm2d subclass: verifies the passes match user-defined
    # BN modules, not only the builtin M.BatchNorm2d.
    pass


class MyBlock(M.Module):
    """Two conv-bn-relu branches scaled by constants and combined with
    const additions — a fixture exercising FuseConvBn, FuseAddMul and
    BackwardFoldScale."""

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        # Attribute names are part of the traced structure; keep them stable.
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        branch_a = F.relu(self.bn(self.conv(x))) * self.scale[0]
        branch_b = F.relu(self.bn2(self.conv2(x))) * self.scale[1]
        out = branch_a + branch_b
        out = out + 4
        out = self.scale[0] + out
        return F.relu(out) * 3


class MyModule(M.Module):
    """Sums the outputs of two MyBlock instances, flattens, and applies a
    final const scaling — the top-level fixture for the pass tests."""

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        combined = self.block_0(x) + self.block_1(x)
        combined = F.reshape(combined, (-1))
        return combined * 3


@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
    """BackwardFoldScale must preserve outputs and eliminate all const muls."""
    net = MyModule(conv_cls, bn_cls)
    net.eval()
    data = mge.Tensor(np.random.random((1, 3, 32, 32)))
    expected = net(data)

    traced = tm.trace_module(net, data).flatten()
    optimized = tm.optimize(traced, "BackwardFoldScale")

    np.testing.assert_allclose(desired=expected, actual=optimized(data), atol=1e-4)
    # Every const multiplication should have been folded into conv weights.
    remaining_muls = optimized.graph.get_method_by_type("__mul__").as_list()
    assert len(remaining_muls) == 0


@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
    """FuseConvBn must preserve outputs and remove every BN op/module."""
    net = MyModule(conv_cls, bn_cls)
    net.eval()
    data = mge.Tensor(np.random.random((1, 3, 32, 32)))
    expected = net(data)

    traced = tm.trace_module(net, data).flatten()
    optimized = tm.optimize(traced, "FuseConvBn")

    np.testing.assert_allclose(desired=expected, actual=optimized(data), atol=1e-4)
    # All batch norms should have been fused into the preceding convs,
    # whether they appear as functional calls...
    assert len(optimized.graph.get_function_by_type(F.batch_norm).as_list()) == 0
    # ...or as BatchNorm2d module calls.
    assert len(optimized.graph.get_module_by_type(M.BatchNorm2d).as_list()) == 0

Loading…
Cancel
Save