
feat(mge/module): deconv fuse bn and relu

GitOrigin-RevId: 5619b397a4
Megvii Engine Team · master · commit da7f250c48
11 changed files with 463 additions and 23 deletions:

  1. imperative/python/megengine/module/__init__.py (+2, -0)
  2. imperative/python/megengine/module/conv.py (+9, -0)
  3. imperative/python/megengine/module/conv_transpose_bn.py (+62, -0)
  4. imperative/python/megengine/module/qat/__init__.py (+2, -1)
  5. imperative/python/megengine/module/qat/conv.py (+11, -2)
  6. imperative/python/megengine/module/qat/conv_transpose_bn.py (+163, -0)
  7. imperative/python/megengine/module/quantized/__init__.py (+2, -1)
  8. imperative/python/megengine/module/quantized/conv.py (+17, -6)
  9. imperative/python/megengine/module/quantized/conv_transpose_bn.py (+53, -0)
  10. imperative/python/megengine/utils/bn_fusion.py (+84, -10)
  11. imperative/python/test/unit/module/test_qat.py (+58, -3)

imperative/python/megengine/module/__init__.py (+2, -0)

@@ -12,11 +12,13 @@ from .conv import (
     ConvRelu2d,
     ConvTranspose2d,
     ConvTranspose3d,
+    ConvTransposeRelu2d,
     DeformableConv2d,
     LocalConv2d,
     RegionRestrictedConv,
 )
 from .conv_bn import ConvBn2d, ConvBnRelu2d
+from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .deformable_psroi_pooling import DeformablePSROIPooling
 from .dropout import Dropout
 from .elemwise import Elemwise


imperative/python/megengine/module/conv.py (+9, -0)

@@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d):
         return relu(self.calc_conv(inp, self.weight, self.bias))
 
 
+class ConvTransposeRelu2d(ConvTranspose2d):
+    r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`.
+    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return relu(self.calc_conv_transpose2d(inp, self.weight, self.bias))
+
+
 class DeformableConv2d(_ConvNd):
     r"""Deformable Convolution.

imperative/python/megengine/module/conv_transpose_bn.py (+62, -0)

@@ -0,0 +1,62 @@
from typing import Tuple, Union

from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import ConvTranspose2d
from .module import Module


class _ConvTransposeBnActivation2d(Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.conv_transpose2d = ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
            **kwargs,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)


class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return self.bn(self.conv_transpose2d(inp))


class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`.
    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return relu(self.bn(self.conv_transpose2d(inp)))
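
For context, a minimal usage sketch of the new float modules (the sizes below are illustrative, not from the commit):

    import numpy as np
    from megengine import tensor
    from megengine.module import ConvTransposeBnRelu2d

    # deconv + BN + relu trained as one float module; it can later be swapped
    # for its QAT and quantized counterparts via quantize_qat / quantize
    m = ConvTransposeBnRelu2d(in_channels=16, out_channels=32, kernel_size=3, stride=2)
    x = tensor(np.random.randn(4, 16, 8, 8).astype("float32"))
    y = m(x)  # conv_transpose2d -> batch norm -> relu
    print(y.shape)  # (4, 32, 17, 17): (8 - 1) * 2 + 3 = 17 per spatial dim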

imperative/python/megengine/module/qat/__init__.py (+2, -1)

@@ -1,7 +1,8 @@
 from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
-from .conv import Conv2d, ConvRelu2d, ConvTranspose2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
+from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QATModule


imperative/python/megengine/module/qat/conv.py (+11, -2)

@@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
     def calc_conv_transpose2d_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
         b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
-        conv = self.calc_conv_transpose2d(inp, w_qat, b_qat)
-        return conv
+        conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat)
+        return conv_transpose2d
 
     @classmethod
     def from_float_module(cls, float_module: Float.ConvTranspose2d):

@@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule):
 
     def forward(self, inp):
         return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp))
+
+
+class ConvTransposeRelu2d(ConvTranspose2d):
+    r"""A :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(F.relu(self.calc_conv_transpose2d_qat(inp)))

imperative/python/megengine/module/qat/conv_transpose_bn.py (+163, -0)

@@ -0,0 +1,163 @@
from ...functional import ones, relu, sqrt, sum, zeros
from .. import conv_transpose_bn as Float
from .module import QATModule


class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get folded bn + conv_transpose2d params
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        if self.conv_transpose2d.groups == 1:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
        else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
                self.conv_transpose2d.groups, 1, -1, 1, 1
            )

        w_fold = self.apply_quant_weight(w_fold)
        b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_conv_transpose2d_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv_transpose2d = self.conv_transpose2d(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d)
            num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_transpose2d_bias
        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        if self.conv_transpose2d.groups == 1:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
        else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
                self.conv_transpose2d.groups, 1, -1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d(
            inp, w_qat, b_qat
        )
        if not (self.training and approx):
            return conv_transpose2d

        # rescale conv_transpose2d to get original conv_transpose2d output
        orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1)
        if self.conv_transpose2d.bias is not None:
            orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias
        # calculate batch norm
        conv_transpose2d = self.bn(orig_conv_transpose2d)
        return conv_transpose2d

    @classmethod
    def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d):
        qat_module = cls(
            float_module.conv_transpose2d.in_channels,
            float_module.conv_transpose2d.out_channels,
            float_module.conv_transpose2d.kernel_size,
            float_module.conv_transpose2d.stride,
            float_module.conv_transpose2d.padding,
            float_module.conv_transpose2d.output_padding,
            float_module.conv_transpose2d.dilation,
            float_module.conv_transpose2d.groups,
            float_module.conv_transpose2d.bias is not None,
            float_module.conv_transpose2d.conv_mode,
            float_module.conv_transpose2d.compute_mode,
            name=float_module.name,
        )
        qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight
        qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp))


class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp)))
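
The folding above relies on the identity BN(deconv(x) + b) = deconv_fold(x) + b_fold: scaling the IOHW weight's output-channel axis by gamma / sqrt(var + eps) is equivalent to scaling the deconv output per channel. A small self-contained NumPy check of that algebra (illustrative, not part of the commit):

    import numpy as np

    C, eps = 8, 1e-5
    gamma = np.random.randn(1, C, 1, 1).astype("float32")
    beta = np.random.randn(1, C, 1, 1).astype("float32")
    mean = np.random.randn(1, C, 1, 1).astype("float32")
    var = np.random.rand(1, C, 1, 1).astype("float32")
    b = np.random.randn(1, C, 1, 1).astype("float32")    # deconv bias
    y0 = np.random.randn(2, C, 5, 5).astype("float32")   # deconv output before bias

    istd = 1.0 / np.sqrt(var + eps)
    b_fold = beta + gamma * (b - mean) * istd

    lhs = gamma * ((y0 + b) - mean) * istd + beta        # BN applied after biased deconv
    rhs = (gamma * istd) * y0 + b_fold                   # folded scale plus folded bias
    np.testing.assert_allclose(lhs, rhs, rtol=1e-4, atol=1e-5)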

imperative/python/megengine/module/quantized/__init__.py (+2, -1)

@@ -1,7 +1,8 @@
 from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
-from .conv import Conv2d, ConvRelu2d, ConvTranspose2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
+from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QuantizedModule


imperative/python/megengine/module/quantized/conv.py (+17, -6)

@@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
-        qconv = cls(
+        qconv_transpose2d = cls(
             qat_module.in_channels,
             qat_module.out_channels,
             qat_module.kernel_size,

@@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
             name=qat_module.name,
         )
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
-        qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
-        qconv.bias = (
+        qconv_transpose2d.weight = Parameter(
+            weight.numpy(), name=qat_module.weight.name
+        )
+        qconv_transpose2d.bias = (
             Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
             if qat_module.bias is not None
             else None
         )
-        return qconv
+        return qconv_transpose2d
 
-    def calc_conv_transpose2d_quantized(self, inp):
+    def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode):
+        assert nonlinear_mode in ("identity", "relu"), "nonlinear_mode should be 'identity' or 'relu'"
         if self.bias is not None:
             inp_scale = dtype.get_scale(inp.dtype)
             w_scale = dtype.get_scale(self.weight.dtype)

@@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule):
         )
 
     def forward(self, inp):
-        return self.calc_conv_transpose2d_quantized(inp)
+        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")
+
+
+class ConvTransposeRelu2d(ConvTranspose2d):
+    r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""
+
+    def forward(self, inp):
+        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")

imperative/python/megengine/module/quantized/conv_transpose_bn.py (+53, -0)

@@ -0,0 +1,53 @@
from ...tensor import Parameter
from ..qat import conv_transpose_bn as QAT
from .conv import ConvTranspose2d


class _ConvTransposeBnActivation2d(ConvTranspose2d):
    r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only."""

    @classmethod
    def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv_transpose2d = cls(
            qat_module.conv_transpose2d.in_channels,
            qat_module.conv_transpose2d.out_channels,
            qat_module.conv_transpose2d.kernel_size,
            qat_module.conv_transpose2d.stride,
            qat_module.conv_transpose2d.padding,
            qat_module.conv_transpose2d.output_padding,
            qat_module.conv_transpose2d.dilation,
            qat_module.conv_transpose2d.groups,
            dtype=output_dtype,
            name=qat_module.name,
        )
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qconv_transpose2d.weight = Parameter(
            weight.numpy(), name=qat_module.conv_transpose2d.weight.name
        )
        qconv_transpose2d.bias = Parameter(b_fold.numpy())
        if qat_module.conv_transpose2d.bias is not None:
            qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name
        return qconv_transpose2d


class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`."""

    def forward(self, inp):
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity")


class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`."""

    def forward(self, inp):
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu")
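
End to end, the intended conversion path mirrors the existing ConvBn2d flow. A hedged sketch (module sizes and the single calibration pass are illustrative, not from the commit):

    import numpy as np
    from megengine import tensor
    from megengine.module import ConvTransposeBnRelu2d
    from megengine.quantization.quantize import quantize, quantize_qat

    m = ConvTransposeBnRelu2d(16, 32, 3)
    qat_m = quantize_qat(m, inplace=False)  # Float -> QAT: fake-quant around the fold
    x = tensor(np.random.randn(1, 16, 8, 8).astype("float32"))
    qat_m(x)                                # let the observers record scales first
    q_m = quantize(qat_m, inplace=False)    # QAT -> Quantized: BN folded for inference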

imperative/python/megengine/utils/bn_fusion.py (+84, -10)

@@ -1,48 +1,70 @@
 from copy import deepcopy
 
 from ..functional import ones, sqrt, zeros
-from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
+from ..module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvBnRelu2d,
+    ConvRelu2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ConvTransposeBnRelu2d,
+    ConvTransposeRelu2d,
+    ReLU,
+)
 from ..tensor import Parameter
 
 _MAP_TO_FUSED_MODULE = {
     (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
     (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
+    (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d,
+    (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d,
     (Conv2d, BatchNorm2d, False): Conv2d,
     (Conv2d, BatchNorm2d, True): ConvBn2d,
     (Conv2d, ReLU): ConvRelu2d,
+    (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
+    (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
+    (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
 }
 
 
-def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
-    # get fold bn conv param
+def fold_weight_bias(
+    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
+):
+    shape = (-1, 1, 1, 1)
+    if transpose:
+        shape = (1, -1, 1, 1)
+
     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
         groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
         groups, num_features = 1, kernel_shape[0]
 
+    out_channels = groups * num_features
     if gamma is None:
-        gamma = ones((num_features), dtype="float32")
+        gamma = ones((out_channels,), dtype="float32")
     gamma = gamma.reshape(1, -1, 1, 1)
     if beta is None:
-        beta = zeros((num_features), dtype="float32")
+        beta = zeros((out_channels,), dtype="float32")
     beta = beta.reshape(1, -1, 1, 1)
 
     if bn_mean is None:
-        bn_mean = zeros((1, num_features, 1, 1), dtype="float32")
+        bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
     if bn_var is None:
-        bn_var = ones((1, num_features, 1, 1), dtype="float32")
+        bn_var = ones((1, out_channels, 1, 1), dtype="float32")
 
     if bias is None:
-        bias = zeros((1, num_features, 1, 1), dtype="float32")
+        bias = zeros((1, out_channels, 1, 1), dtype="float32")
 
     bn_istd = 1.0 / sqrt(bn_var + eps)
     scale_factor = gamma * bn_istd
 
     if groups == 1:
-        w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
+        w_fold = weight * scale_factor.reshape(*shape)
     else:
-        w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)
+        w_fold = weight * scale_factor.reshape(groups, *shape)
 
     b_fold = beta + gamma * (bias - bn_mean) * bn_istd
     return w_fold, b_fold

@@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
     module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
+
+
+def fuse_conv_transpose2d_bn_relu_module(
+    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
+):
+    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
+    if bn:
+        assert (
+            conv_transpose2d.training == bn.training
+        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
+        assert (
+            bn.num_features == conv_transpose2d.out_channels
+        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
+        module_key = module_key + (conv_transpose2d.training,)
+    module = _MAP_TO_FUSED_MODULE[module_key](
+        in_channels=conv_transpose2d.in_channels,
+        out_channels=conv_transpose2d.out_channels,
+        kernel_size=conv_transpose2d.kernel_size,
+        stride=conv_transpose2d.stride,
+        padding=conv_transpose2d.padding,
+        output_padding=conv_transpose2d.output_padding,
+        dilation=conv_transpose2d.dilation,
+        groups=conv_transpose2d.groups,
+        bias=conv_transpose2d.bias is not None,
+        conv_mode=conv_transpose2d.conv_mode,
+        compute_mode=conv_transpose2d.compute_mode,
+        name=conv_transpose2d.name,
+    )
+    new_conv_transpose2d = (
+        module
+        if bn is None or not conv_transpose2d.training
+        else module.conv_transpose2d
+    )
+    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
+    if not conv_transpose2d.training and bn is not None:
+        weight, bias = fold_weight_bias(
+            weight,
+            bias,
+            bn.weight,
+            bn.bias,
+            bn.running_mean,
+            bn.running_var,
+            bn.eps,
+            transpose=True,
+        )
+    new_conv_transpose2d.weight = Parameter(weight)
+    if bias is not None:
+        new_conv_transpose2d.bias = Parameter(bias)
+    if bn is not None and conv_transpose2d.training:
+        module.bn = deepcopy(bn)
+    new_conv_transpose2d.training = conv_transpose2d.training
+    return module
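
A usage sketch for the new fuse helper (hypothetical sizes; in eval mode the BN statistics are folded into the deconv weights, while in train mode a ConvTransposeBn2d/ConvTransposeBnRelu2d wrapper is returned instead):

    from megengine.module import BatchNorm2d, ConvTranspose2d, ReLU
    from megengine.utils.bn_fusion import fuse_conv_transpose2d_bn_relu_module

    deconv, bn = ConvTranspose2d(16, 32, 3), BatchNorm2d(32)
    deconv.eval()
    bn.eval()
    # eval mode: returns a plain ConvTranspose2d with folded weights
    fused = fuse_conv_transpose2d_bn_relu_module(deconv, bn, None)
    # with a trailing ReLU: returns a ConvTransposeRelu2d
    fused_relu = fuse_conv_transpose2d_bn_relu_module(deconv, bn, ReLU())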

imperative/python/test/unit/module/test_qat.py (+58, -3)

@@ -5,7 +5,9 @@ import numpy as np
 import pytest
 
 import megengine.utils.comp_graph_tools as cgtools
-from megengine import jit, tensor
+from megengine import jit
+from megengine import module as M
+from megengine import tensor
 from megengine.device import get_device_count
 from megengine.functional import expand_dims
 from megengine.module import (

@@ -14,6 +16,8 @@ from megengine.module import (
     ConvBn2d,
     ConvRelu2d,
     ConvTranspose2d,
+    ConvTransposeBn2d,
+    ConvTransposeRelu2d,
     DequantStub,
     Module,
     QuantStub,

@@ -34,6 +38,49 @@ def test_qat_convbn2d():
         module = ConvBn2d(
             in_channels, out_channels, kernel_size, groups=groups, bias=bias
         )
+        M.init.normal_(module.bn.weight)
+        M.init.normal_(module.bn.bias)
+        module.train()
+        qat_module = quantize_qat(module, inplace=False)
+        disable_fake_quant(qat_module)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+        )
+        np.testing.assert_allclose(
+            module.bn.running_mean.numpy(),
+            qat_module.bn.running_mean.numpy(),
+            atol=5e-8,
+        )
+        np.testing.assert_allclose(
+            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+        )
+        module.eval()
+        normal_outputs = module(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module(inputs)
+        np.testing.assert_allclose(
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+        )
+
+
+def test_qat_convtransposebn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    for groups, bias in product([1, 4], [True, False]):
+        module = ConvTransposeBn2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            output_padding=0,
+            groups=groups,
+            bias=bias,
+        )
+        M.init.normal_(module.bn.weight)
+        M.init.normal_(module.bn.bias)
         module.train()
         qat_module = quantize_qat(module, inplace=False)
         disable_fake_quant(qat_module)

@@ -235,10 +282,14 @@ def test_qat_conv_transpose2d():
             self.conv = ConvTranspose2d(
                 in_channels, out_channels, kernel_size, bias=bias
             )
+            self.conv_transpose2d_relu = ConvTransposeRelu2d(
+                out_channels, in_channels, kernel_size, bias=bias
+            )
 
         def forward(self, inp):
             out = self.quant(inp)
             out = self.conv(out)
+            out = self.conv_transpose2d_relu(out)
             out = self.dequant(out)
             return out

@@ -250,10 +301,14 @@ def test_qat_conv_transpose2d():
     disable_fake_quant(qat_net)
     normal_outputs = net(inputs)
     qat_outputs = qat_net(inputs)
-    np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+    np.testing.assert_allclose(
+        normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
+    )
 
     net.eval()
     normal_outputs = net(inputs)
     qat_net.eval()
     qat_outputs = qat_net(inputs)
-    np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+    np.testing.assert_allclose(
+        normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
+    )
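
The new tests sit beside the existing conv/BN QAT coverage and can be run in isolation with pytest's keyword filter, e.g. `pytest imperative/python/test/unit/module/test_qat.py -k "convtransposebn2d or conv_transpose2d"`.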
