GitOrigin-RevId: bf54012155
tags/v1.3.0
@@ -5,8 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ...quantization.utils import fake_quant_bias
 from .. import batch_matmul_activation as Float
 from .module import QATModule
@@ -18,7 +16,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
     def forward(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
         return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

     @classmethod
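The hunk above and the three file hunks that follow apply the same refactor; judging by the relative imports, the files are the QAT batch_matmul_activation, conv, conv_bn, and linear modules. Each open-coded fake_quant_bias call, plus the weight_fake_quant gate where one existed, becomes a call to the shared QATModule.apply_quant_bias helper added to module.py later in this diff. Schematically:

    # Before: bias fake quant gated on weight_fake_quant alone
    # (or not gated at all, as in BatchMatMulActivation above).
    if self.weight_fake_quant and self.weight_fake_quant.enabled:
        b_qat = fake_quant_bias(self.bias, inp, w_qat)
    else:
        b_qat = self.bias

    # After: one shared helper, which additionally requires act_fake_quant
    # to be enabled (the bias shares its dtype with the activation).
    b_qat = self.apply_quant_bias(self.bias, inp, w_qat)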
@@ -6,7 +6,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
-from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule
@@ -19,10 +18,7 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        if self.weight_fake_quant and self.weight_fake_quant.enabled:
-            b_qat = fake_quant_bias(self.bias, inp, w_qat)
-        else:
-            b_qat = self.bias
+        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
         conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
@@ -6,7 +6,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...functional import ones, relu, sqrt, sum, zeros
-from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule
@@ -122,10 +121,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_quant_weight(w_fold)
-        if self.weight_fake_quant and self.weight_fake_quant.enabled:
-            b_qat = fake_quant_bias(b_fold, inp, w_qat)
-        else:
-            b_qat = b_fold
+        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
         conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv
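The b_fold context line above is the standard batch-norm folding identity: conv followed by BN equals a single conv whose weight is scaled by gamma * bn_istd and whose bias is beta + gamma * (conv_bias - bn_mean) * bn_istd. A scalar numpy check of that algebra (illustrative sketch only, not MegEngine code):

    import numpy as np

    # Check: gamma * ((w*x + b) - mean) * istd + beta
    #     == (gamma * w * istd) * x + (beta + gamma * (b - mean) * istd)
    rng = np.random.default_rng(0)
    x, w, b = rng.normal(size=3)
    gamma, beta, mean, var = 1.3, 0.2, 0.5, 2.0
    istd = 1.0 / np.sqrt(var)  # eps omitted for brevity

    y_bn = gamma * ((w * x + b) - mean) * istd + beta
    w_fold = gamma * w * istd
    b_fold = beta + gamma * (b - mean) * istd  # matches b_fold in the hunk
    assert np.isclose(y_bn, w_fold * x + b_fold)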
@@ -5,7 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule
@@ -22,13 +21,10 @@ class Linear(Float.Linear, QATModule):
     """

-    def forward(self, x):
+    def forward(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        if self.weight_fake_quant and self.weight_fake_quant.enabled:
-            b_qat = fake_quant_bias(self.bias, x, w_qat)
-        else:
-            b_qat = self.bias
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
+        b_qat = self.apply_quant_bias(self.bias, inp, w_qat)
+        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):
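The x -> inp rename in Linear.forward also matches the parameter name used by the other QAT modules. For context, a sketch of how these forwards are normally reached, assuming the stock ema_fakequant_qconfig (the exact invocation is illustrative, not part of this diff):

    import numpy as np
    import megengine as mge
    import megengine.module as M
    from megengine.quantization import ema_fakequant_qconfig, quantize_qat

    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.fc = M.Linear(16, 8)

        def forward(self, x):
            return self.fc(x)

    net = Net()
    # Swaps float modules for their QAT counterparts; the QAT Linear.forward
    # above then runs apply_quant_weight -> apply_quant_bias
    # -> apply_quant_activation.
    qat_net = quantize_qat(net, qconfig=ema_fakequant_qconfig)
    out = qat_net(mge.tensor(np.random.randn(4, 16).astype("float32")))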
@@ -11,6 +11,7 @@ from abc import abstractmethod
 from ...quantization.fake_quant import FakeQuantize
 from ...quantization.observer import Observer
 from ...quantization.qconfig import QConfig
+from ...quantization.utils import fake_quant_bias
 from ...tensor import Tensor
 from ..module import Module
@@ -107,6 +108,24 @@ class QATModule(Module):
             target, self.act_fake_quant, self.act_observer
         )

+    def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
+        r"""
+        Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
+        ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
+        """
+        # bias should have the same dtype as activation, so act_fake_quant can also
+        # decide whether to do bias fakequant
+        if (
+            self.act_fake_quant
+            and self.act_fake_quant.enabled
+            and self.weight_fake_quant
+            and self.weight_fake_quant.enabled
+        ):
+            b_qat = fake_quant_bias(target, inp, w_qat)
+        else:
+            b_qat = target
+        return b_qat
+
     def _get_method_result(
         self, method: str, fake_quant: FakeQuantize, observer: Observer
     ):
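Besides centralizing the gate, the helper adds the act_fake_quant condition explained by the inline comment. Conceptually, fake_quant_bias simulates the int32 bias of an int8 conv/matmul, whose scale is the product of the input and weight scales; a minimal sketch under that assumption (hypothetical helper, the real function derives the scales from the tensors' qparams):

    import numpy as np

    def fake_quant_bias_sketch(bias, inp_scale, w_scale):
        # The bias feeds the int32 accumulator of an int8 conv/matmul,
        # so its quantization scale is inp_scale * w_scale.
        scale = inp_scale * w_scale
        q = np.clip(np.round(bias / scale), -(2 ** 31), 2 ** 31 - 1)
        return q * scale  # "fake" quant: stay in float, keep only the rounding error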
@@ -30,4 +30,10 @@ from .quantize import (
     quantize_qat,
     reset_qconfig,
 )
-from .utils import QParams, QuantMode, create_qparams
+from .utils import (
+    QParams,
+    QuantMode,
+    create_qparams,
+    fake_quant_bias,
+    fake_quant_tensor,
+)
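With this re-export, both helpers become importable from the package root alongside the existing utils (their call signatures live in the .utils module itself):

    from megengine.quantization import fake_quant_bias, fake_quant_tensor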