
feat(mge/module): add fused conv_bn qat approximate version

GitOrigin-RevId: 1b7284a595
tags/v0.5.0
Megvii Engine Team, 5 years ago
Parent commit: 980ebf2c72

2 changed files with 138 additions and 50 deletions:
  1. python_module/megengine/module/conv_bn_relu.py (+99, -50)
  2. python_module/test/unit/module/test_conv_bn_relu.py (+39, -0)

python_module/megengine/module/conv_bn_relu.py (+99, -50)

@@ -8,7 +8,7 @@
 from typing import Tuple, Union

 from ..core import ones, zeros
-from ..functional import flatten, relu, sqrt, sum
+from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
 from .module import QATModule
@@ -31,7 +31,6 @@ class _ConvBn2d(QATModule):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
-        freeze_bn=False,
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -47,28 +46,6 @@ class _ConvBn2d(QATModule):
             compute_mode,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
-        self.freeze_bn = freeze_bn
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        return self
-
-    def get_bn_gamma_beta(self):
-        if self.bn.weight is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        else:
-            gamma = self.bn.weight
-
-        if self.bn.bias is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        else:
-            beta = self.bn.bias
-
-        return gamma, beta

     def get_batch_mean_var(self, inp):
         def _sum_channel(inp, axis=0, keepdims=True):
@@ -83,8 +60,7 @@ class _ConvBn2d(QATModule):
         sum2 = _sum_channel(inp ** 2, (0, 2, 3))
         reduce_size = inp.shapeof().prod() / inp.shapeof(1)
         batch_mean = sum1 / reduce_size
-        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
-
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
         return batch_mean, batch_var

     def fold_weight_bias(self, bn_mean, bn_var):
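The rewritten get_batch_mean_var returns the biased batch variance (dividing by the per-channel element count N rather than N - 1); the unbiased correction is applied only later, when the running statistics are updated. A small numpy sketch of the two estimators, purely illustrative and not part of the commit:

    import numpy as np

    x = np.random.randn(4, 16, 8, 8).astype(np.float32)   # NCHW activations
    n = x.shape[0] * x.shape[2] * x.shape[3]               # elements per channel

    s1 = x.sum(axis=(0, 2, 3))
    s2 = (x ** 2).sum(axis=(0, 2, 3))
    batch_mean = s1 / n
    batch_var = (s2 - s1 ** 2 / n) / n            # biased, used for normalization
    unbiased_var = (s2 - s1 ** 2 / n) / (n - 1)   # used for the running statistics

    # the unbiased estimate is just the biased one rescaled by n / (n - 1)
    assert np.allclose(batch_var * n / (n - 1), unbiased_var, rtol=1e-4)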
@@ -92,50 +68,123 @@ class _ConvBn2d(QATModule):
         # bn_istd = 1 / bn_std
         # w_fold = gamma / bn_std * W
         # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        gamma, beta = self.get_bn_gamma_beta()
-        b = self.conv.bias
-        if b is None:
-            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)

         if bn_mean is None:
             bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
         if bn_var is None:
             bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
         if self.conv.groups == 1:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(-1, 1, 1, 1)
-                * bn_istd.reshape(-1, 1, 1, 1)
-            )
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
         else:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
-                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
             )
-        b_fold = flatten(beta) + (
-            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
-        )
-        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         return w_fold, b_fold
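fold_weight_bias relies on the identity that a per-channel BatchNorm applied after a convolution equals a convolution with rescaled weights and a shifted bias. A minimal numpy check of that identity, using a 1x1 convolution so the conv reduces to a channel matmul (illustrative only; none of these names come from MegEngine):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3, 8, 8)).astype(np.float32)   # NCHW input
    w = rng.standard_normal((4, 3, 1, 1)).astype(np.float32)   # 1x1 conv weight (out, in, kh, kw)
    b = rng.standard_normal(4).astype(np.float32)              # conv bias
    gamma = rng.standard_normal(4).astype(np.float32)          # BN weight
    beta = rng.standard_normal(4).astype(np.float32)           # BN bias
    mean = rng.standard_normal(4).astype(np.float32)           # BN mean
    var = (rng.random(4) + 0.5).astype(np.float32)             # BN variance (positive)
    eps = 1e-5

    def conv1x1(x, w, b):
        # a 1x1 convolution is a per-pixel matmul over channels
        y = np.einsum("nchw,kc->nkhw", x, w[:, :, 0, 0])
        return y + b.reshape(1, -1, 1, 1)

    bn_istd = 1.0 / np.sqrt(var + eps)
    # conv followed by batch norm
    bn_out = (
        gamma.reshape(1, -1, 1, 1)
        * (conv1x1(x, w, b) - mean.reshape(1, -1, 1, 1))
        * bn_istd.reshape(1, -1, 1, 1)
        + beta.reshape(1, -1, 1, 1)
    )

    # folded conv: w_fold = gamma / bn_std * W, b_fold = beta + gamma * (b - mean) / bn_std
    w_fold = w * (gamma * bn_istd).reshape(-1, 1, 1, 1)
    b_fold = beta + gamma * (b - mean) * bn_istd
    assert np.allclose(bn_out, conv1x1(x, w_fold, b_fold), atol=1e-5)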


-    def calc_conv_bn_qat(self, inp):
-        # TODO: use pytorch method as
-        conv = self.conv(inp)
-        self.bn(conv)
+    def update_running_mean_and_running_var(
+        self, bn_mean, bn_var, num_elements_per_channel
+    ):
+        # update running mean and running var. no grad, use unbiased bn var
+        bn_mean = zero_grad(bn_mean)
+        bn_var = (
+            zero_grad(bn_var)
+            * num_elements_per_channel
+            / (num_elements_per_channel - 1)
+        )
+        exponential_average_factor = 1 - self.bn.momentum
+        add_update(
+            self.bn.running_mean,
+            delta=bn_mean,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+        add_update(
+            self.bn.running_var,
+            delta=bn_var,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
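The new helper keeps the running statistics as an exponential moving average of the batch statistics, after rescaling the biased batch variance to its unbiased form. A plain-numbers sketch of one update step, assuming add_update(dest, delta, alpha=a, beta=b) computes dest := a * dest + b * delta (that reading of add_update is an assumption, not something the diff states):

    momentum = 0.9                       # constructor default above: keep 90% of the old value
    running_mean, running_var = 0.0, 1.0
    batch_mean, batch_var = 0.3, 0.8     # biased batch statistics for one channel
    n = 4 * 32 * 32                      # elements per channel (N * H * W)

    unbiased_var = batch_var * n / (n - 1)   # unbiased correction before the update
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    running_var = momentum * running_var + (1 - momentum) * unbiased_var
    print(running_mean, running_var)         # 0.03, ~0.98002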


-        if self.training:
+    def calc_conv_bn_qat(self, inp, approx=True):
+        if self.training and not approx:
+            conv = self.conv(inp)
             bn_mean, bn_var = self.get_batch_mean_var(conv)
+            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+            self.update_running_mean_and_running_var(
+                bn_mean, bn_var, num_elements_per_channel
+            )
         else:
             bn_mean, bn_var = self.bn.running_mean, self.bn.running_var


-        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
+        # get gamma and beta in BatchNorm
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+        # conv_bias
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        b_fold = None
+        if not (self.training and approx):
+            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
+            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_fakequant_with_observer(
             w_fold, self.weight_fake_quant, self.weight_observer
         )
-        return self.conv.calc_conv(inp, w_qat, b_fold)
+        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        if not (self.training and approx):
+            return conv
+
+        # rescale conv to get original conv output
+        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
+        if self.conv.bias is not None:
+            orig_conv = orig_conv + self.conv.bias
+        # calculate batch norm
+        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
+        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+        self.update_running_mean_and_running_var(
+            bn_mean, bn_var, num_elements_per_channel
+        )
+        return conv
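In the approximate training path the convolution runs with fake-quantized folded weights and no bias, and the result is divided by scale_factor to recover the plain conv output before a regular batch norm with batch statistics is applied. This works because the folding only rescales each output channel; a small numpy check of that property (illustrative 1x1 convolution, names are not from the codebase):

    import numpy as np

    rng = np.random.default_rng(1)
    x = rng.standard_normal((2, 3, 5, 5)).astype(np.float32)
    w = rng.standard_normal((4, 3, 1, 1)).astype(np.float32)
    s = (rng.random(4) + 0.5).astype(np.float32)       # per-output-channel scale_factor

    def conv1x1(x, w):
        return np.einsum("nchw,kc->nkhw", x, w[:, :, 0, 0])

    # conv with per-channel-scaled weights, divided by the same scale,
    # gives back the unscaled conv output
    scaled_out = conv1x1(x, w * s.reshape(-1, 1, 1, 1))
    assert np.allclose(scaled_out / s.reshape(1, -1, 1, 1), conv1x1(x, w), atol=1e-5)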




 class ConvBn2d(_ConvBn2d):


python_module/test/unit/module/test_conv_bn_relu.py (+39, -0)

@@ -0,0 +1,39 @@
+import copy
+from itertools import product
+
+import numpy as np
+
+from megengine import tensor
+from megengine.module import ConvBn2d
+from megengine.quantization import quantize_qat
+from megengine.quantization.quantize import disable_fake_quant
+from megengine.test import assertTensorClose
+
+
+def test_convbn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    module = ConvBn2d(in_channels, out_channels, kernel_size)
+    quantize_qat(module)
+    for groups, bias in product([1, 4], [True, False]):
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        module.train()
+        qat_module = copy.deepcopy(module)
+        disable_fake_quant(qat_module)
+        normal_outputs = module.forward(inputs)
+        qat_outputs = qat_module.forward_qat(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+        a = module.bn.running_mean.numpy()
+        b = qat_module.bn.running_mean.numpy()
+        assertTensorClose(
+            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
+        )
+        assertTensorClose(
+            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
+        )
+        module.eval()
+        normal_outputs = module.forward(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module.forward_qat(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
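For completeness, a hypothetical snippet that compares the exact and approximate paths by calling calc_conv_bn_qat directly on a QAT-converted ConvBn2d; the test above goes through forward_qat instead, so invoking calc_conv_bn_qat this way is an assumption about the API, not something the commit demonstrates:

    import copy

    import numpy as np

    from megengine import tensor
    from megengine.module import ConvBn2d
    from megengine.quantization import quantize_qat
    from megengine.quantization.quantize import disable_fake_quant

    m = ConvBn2d(3, 8, 3)
    quantize_qat(m)
    disable_fake_quant(m)      # compare the float paths only, no fake quantization
    m.train()

    x = tensor(np.random.randn(2, 3, 16, 16).astype(np.float32))
    exact = copy.deepcopy(m).calc_conv_bn_qat(x, approx=False)   # deepcopy so running stats update once each
    approx = m.calc_conv_bn_qat(x, approx=True)
    print(np.abs(exact.numpy() - approx.numpy()).max())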
