- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from typing import Tuple, Union
-
- from ..core import ones, zeros
- from ..functional import flatten, relu, sqrt, sum
- from .batchnorm import BatchNorm2d
- from .conv import Conv2d
- from .module import QATModule
-
-
- class _ConvBn2d(QATModule):
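-     r"""
-     Shared base for the fused Conv2d + BatchNorm2d :class:`~.QATModule` variants below.
-     It owns a :class:`~.Conv2d` and a :class:`~.BatchNorm2d` and, in ``qat`` mode, folds
-     the BN statistics and affine parameters into the convolution weight and bias before
-     fake quantization.
-     """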
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: Union[int, Tuple[int, int]],
-         stride: Union[int, Tuple[int, int]] = 1,
-         padding: Union[int, Tuple[int, int]] = 0,
-         dilation: Union[int, Tuple[int, int]] = 1,
-         groups: int = 1,
-         bias: bool = True,
-         conv_mode: str = "CROSS_CORRELATION",
-         compute_mode: str = "DEFAULT",
-         eps=1e-5,
-         momentum=0.9,
-         affine=True,
-         track_running_stats=True,
-         freeze_bn=False,
-     ):
-         super().__init__()
-         self.conv = Conv2d(
-             in_channels,
-             out_channels,
-             kernel_size,
-             stride,
-             padding,
-             dilation,
-             groups,
-             bias,
-             conv_mode,
-             compute_mode,
-         )
-         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
-         self.freeze_bn = freeze_bn
-
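-     # Switches that record whether BN running statistics should keep updating; note that
-     # `calc_conv_bn_qat` below currently chooses between batch and running statistics
-     # based on `self.training`.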
-     def update_bn_stats(self):
-         self.freeze_bn = False
-         return self
-
-     def freeze_bn_stats(self):
-         self.freeze_bn = True
-         return self
-
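-     # Return the BN affine parameters, falling back to the identity transform
-     # (gamma = 1, beta = 0) when `affine=False` leaves weight/bias unset.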
-     def get_bn_gamma_beta(self):
-         if self.bn.weight is None:
-             gamma = ones((self.bn.num_features,), dtype="float32")
-         else:
-             gamma = self.bn.weight
-
-         if self.bn.bias is None:
-             beta = zeros((self.bn.num_features,), dtype="float32")
-         else:
-             beta = self.bn.bias
-
-         return gamma, beta
-
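-     # Per-channel batch statistics: mean and unbiased variance reduced over the
-     # N, H, W axes of an NCHW input.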
-     def get_batch_mean_var(self, inp):
-         def _sum_channel(inp, axis=0, keepdims=True):
-             if isinstance(axis, int):
-                 out = sum(inp, axis=axis, keepdims=keepdims)
-             elif isinstance(axis, tuple):
-                 for idx, elem in enumerate(axis):
-                     out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
-             return out
-
-         sum1 = _sum_channel(inp, (0, 2, 3))
-         sum2 = _sum_channel(inp ** 2, (0, 2, 3))
-         reduce_size = inp.shapeof().prod() / inp.shapeof(1)
-         batch_mean = sum1 / reduce_size
-         # unbiased sample variance, with Bessel's correction (divide by N - 1)
-         batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
-
-         return batch_mean, batch_var
-
-     def fold_weight_bias(self, bn_mean, bn_var):
-         # fold the BN parameters into the conv weight and bias:
-         #   bn_istd = 1 / bn_std
-         #   w_fold = gamma / bn_std * W
-         #   b_fold = gamma * (b - bn_mean) / bn_std + beta
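-         # This fold is exact per output channel because the conv is linear in its weight:
-         #   bn(conv(x, W) + b) = gamma * (conv(x, W) + b - bn_mean) / bn_std + beta
-         #                      = conv(x, w_fold) + b_fold
-         # so a single conv with (w_fold, b_fold) reproduces conv followed by BN.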
-         gamma, beta = self.get_bn_gamma_beta()
-         b = self.conv.bias
-         if b is None:
-             b = zeros(self.conv._infer_bias_shape(), dtype="float32")
-         if bn_mean is None:
-             bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
-         if bn_var is None:
-             bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
-
-         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
-         # grouped conv weights carry an extra leading `groups` dimension, so the
-         # per-channel scales must be reshaped accordingly
-         if self.conv.groups == 1:
-             w_fold = (
-                 self.conv.weight
-                 * gamma.reshape(-1, 1, 1, 1)
-                 * bn_istd.reshape(-1, 1, 1, 1)
-             )
-         else:
-             w_fold = (
-                 self.conv.weight
-                 * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
-                 * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
-             )
-         b_fold = flatten(beta) + (
-             flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
-         )
-         b_fold = b_fold.reshape(self.conv._infer_bias_shape())
-
-         return w_fold, b_fold
-
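-     # QAT forward path: run the float conv, push it through BN only to update the
-     # running statistics (its output is discarded), fold the chosen statistics
-     # (batch stats in training, running stats otherwise) into the conv parameters,
-     # fake-quantize the folded weight, and redo the conv with the folded weight and bias.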
-     def calc_conv_bn_qat(self, inp):
-         # TODO: use pytorch method as reference
-         conv = self.conv(inp)
-         self.bn(conv)
-
-         if self.training:
-             bn_mean, bn_var = self.get_batch_mean_var(conv)
-         else:
-             bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
-
-         w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
-         w_qat = self.apply_fakequant_with_observer(
-             w_fold, self.weight_fake_quant, self.weight_observer
-         )
-         return self.conv.calc_conv(inp, w_qat, b_fold)
-
-
- class ConvBn2d(_ConvBn2d):
- r"""
- A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
- and ``normal`` mode.
- """
-
- def forward_qat(self, inp):
- return self.apply_fakequant_with_observer(
- self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
- )
-
- def forward(self, inp):
- return self.bn(self.conv(inp))
-
-
- class ConvBnRelu2d(_ConvBn2d):
- r"""
- A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
- mode and ``normal`` mode.
- """
-
- def forward_qat(self, inp):
- return self.apply_fakequant_with_observer(
- relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
- )
-
- def forward(self, inp):
- return relu(self.bn(self.conv(inp)))