# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import Module


class _ConvBnActivation2d(Module):
    r"""Shared base for fused ``Conv2d`` + ``BatchNorm2d`` modules.

    Builds a :class:`~.module.Conv2d` stored as ``self.conv`` and a
    :class:`~.module.BatchNorm2d` over ``out_channels`` stored as ``self.bn``.
    Subclasses define ``forward`` to decide how the two are chained and
    which activation (if any) follows.

    :param in_channels: forwarded to ``Conv2d``.
    :param out_channels: forwarded to ``Conv2d`` and used as the number of
        features of ``BatchNorm2d``.
    :param kernel_size: forwarded to ``Conv2d``.
    :param stride: forwarded to ``Conv2d``. Default: 1
    :param padding: forwarded to ``Conv2d``. Default: 0
    :param dilation: forwarded to ``Conv2d``. Default: 1
    :param groups: forwarded to ``Conv2d``. Default: 1
    :param bias: forwarded to ``Conv2d``. Default: True
    :param conv_mode: forwarded to ``Conv2d``. Default: "cross_correlation"
    :param compute_mode: forwarded to ``Conv2d``. Default: "default"
    :param eps: forwarded to ``BatchNorm2d``. Default: 1e-5
    :param momentum: forwarded to ``BatchNorm2d``. Default: 0.9
    :param affine: forwarded to ``BatchNorm2d``. Default: True
    :param track_running_stats: forwarded to ``BatchNorm2d``. Default: True
    :param padding_mode: forwarded to ``Conv2d``. Default: "zeros"
    :param kwargs: forwarded to ``Module.__init__`` and to ``Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        padding_mode: str = "zeros",
        **kwargs
    ):
        # NOTE(review): ``kwargs`` is forwarded both to this Module and to the
        # inner Conv2d (e.g. a ``name`` would be applied to both) — confirm
        # this double forwarding is intended.
        super().__init__(**kwargs)
        # Arguments are passed positionally, so their order must match the
        # Conv2d signature (padding_mode comes after compute_mode there).
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
            padding_mode,
            **kwargs,
        )
        # BatchNorm normalizes the conv output, hence out_channels features.
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)


class ConvBn2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and :class:`~.module.BatchNorm2d`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBn2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        r"""Apply ``self.conv`` to ``inp`` and then ``self.bn`` to the result."""
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.Conv2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvBnRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        r"""Apply ``self.conv``, then ``self.bn``, then :func:`~.relu` to ``inp``."""
        return relu(self.bn(self.conv(inp)))