@@ -8,7 +8,7 @@
 from typing import Tuple, Union
 
 from ..core import ones, zeros
-from ..functional import flatten, relu, sqrt, sum
+from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
 from .module import QATModule
@@ -31,7 +31,6 @@ class _ConvBn2d(QATModule):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
-        freeze_bn=False,
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -47,28 +46,6 @@ class _ConvBn2d(QATModule):
             compute_mode,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
-        self.freeze_bn = freeze_bn
-
-    def update_bn_stats(self):
-        self.freeze_bn = False
-        return self
-
-    def freeze_bn_stats(self):
-        self.freeze_bn = True
-        return self
-
-    def get_bn_gamma_beta(self):
-        if self.bn.weight is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        else:
-            gamma = self.bn.weight
-
-        if self.bn.bias is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        else:
-            beta = self.bn.bias
-
-        return gamma, beta
 
     def get_batch_mean_var(self, inp):
         def _sum_channel(inp, axis=0, keepdims=True):
@@ -83,8 +60,7 @@ class _ConvBn2d(QATModule):
         sum2 = _sum_channel(inp ** 2, (0, 2, 3))
         reduce_size = inp.shapeof().prod() / inp.shapeof(1)
         batch_mean = sum1 / reduce_size
-        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
-
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
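+        # NOTE: this is the biased variance (divide by N); the unbiased
+        # N / (N - 1) correction is applied only when the running statistics
+        # are updated in update_running_mean_and_running_var below.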
         return batch_mean, batch_var
 
     def fold_weight_bias(self, bn_mean, bn_var):
@@ -92,50 +68,123 @@ class _ConvBn2d(QATModule):
         # bn_istd = 1 / bn_std
         # w_fold = gamma / bn_std * W
         # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        gamma, beta = self.get_bn_gamma_beta()
-        b = self.conv.bias
-        if b is None:
-            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+
+        if bn_mean is None:
+            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
+        if bn_var is None:
+            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
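+        # missing statistics default to the identity normalization
+        # (mean 0, variance 1), so folding degenerates to the plain
+        # affine transform by gamma and beta.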
+
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
 
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
         if self.conv.groups == 1:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(-1, 1, 1, 1)
-                * bn_istd.reshape(-1, 1, 1, 1)
-            )
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
         else:
-            w_fold = (
-                self.conv.weight
-                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
-                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
             )
-        b_fold = flatten(beta) + (
-            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
-        )
-        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         return w_fold, b_fold
 
-    def calc_conv_bn_qat(self, inp):
-        # TODO: use pytorch method as
-        conv = self.conv(inp)
-        self.bn(conv)
+    def update_running_mean_and_running_var(
+        self, bn_mean, bn_var, num_elements_per_channel
+    ):
+        # update running mean and running var. no grad, use unbiased bn var
+        bn_mean = zero_grad(bn_mean)
+        bn_var = (
+            zero_grad(bn_var)
+            * num_elements_per_channel
+            / (num_elements_per_channel - 1)
+        )
+        exponential_average_factor = 1 - self.bn.momentum
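+        # add_update applies the in-place moving-average step
+        # dest = alpha * dest + beta * delta, i.e. here
+        # running = momentum * running + (1 - momentum) * batch_stat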
+        add_update(
+            self.bn.running_mean,
+            delta=bn_mean,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+        add_update(
+            self.bn.running_var,
+            delta=bn_var,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
 
-        if self.training:
+    def calc_conv_bn_qat(self, inp, approx=True):
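+        # approx=True (default): fold with running statistics, run the
+        # fake-quantized conv once, then correct its output with batch
+        # statistics; approx=False: run a real conv first to obtain exact
+        # batch statistics before folding.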
+        if self.training and not approx:
+            conv = self.conv(inp)
             bn_mean, bn_var = self.get_batch_mean_var(conv)
+            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+            self.update_running_mean_and_running_var(
+                bn_mean, bn_var, num_elements_per_channel
+            )
         else:
             bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
 
-        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
+        # get gamma and beta in BatchNorm
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+        # conv_bias
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        b_fold = None
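+        # in the approx training path the bias cannot be folded yet:
+        # the batch statistics are only known after the conv below runs.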
+        if not (self.training and approx):
+            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
+            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
 
         w_qat = self.apply_fakequant_with_observer(
             w_fold, self.weight_fake_quant, self.weight_observer
         )
-        return self.conv.calc_conv(inp, w_qat, b_fold)
+        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        if not (self.training and approx):
+            return conv
+
+        # rescale conv to get original conv output
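+        # w_qat is approximately scale_factor * W (up to fake-quantization
+        # error), so dividing by scale_factor and adding back the original
+        # bias approximately recovers the unfolded conv output.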
+        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
+        if self.conv.bias is not None:
+            orig_conv = orig_conv + self.conv.bias
+        # calculate batch norm
+        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
+        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+        self.update_running_mean_and_running_var(
+            bn_mean, bn_var, num_elements_per_channel
+        )
+        return conv
 
 
 class ConvBn2d(_ConvBn2d):