From 387867f8ae8373884b819273799d22c252bc1cba Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 28 Aug 2020 13:54:51 +0800
Subject: [PATCH] feat(mge/quantization): add cambricon-quantization-example

GitOrigin-RevId: ed578ca92bccca84367a6bd5c4492ddcf759f50e
---
 python_module/megengine/module/qat/module.py       | 2 +-
 python_module/megengine/module/quantized/linear.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py
index 8546e4c6..c7cb80cb 100644
--- a/python_module/megengine/module/qat/module.py
+++ b/python_module/megengine/module/qat/module.py
@@ -52,7 +52,7 @@ class QATModule(Module):
         self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
 
     def _enable_exec(self, with_module, func, enable):
-        if not with_module:
+        if not with_module or not func:
             return
         if enable:
             func.enable()
diff --git a/python_module/megengine/module/quantized/linear.py b/python_module/megengine/module/quantized/linear.py
index 4c798929..a6e61a6e 100644
--- a/python_module/megengine/module/quantized/linear.py
+++ b/python_module/megengine/module/quantized/linear.py
@@ -32,11 +32,13 @@ class Linear(QuantizedModule):
         inp_scale = mgb.dtype.get_scale(inp.dtype)
         w_scale = mgb.dtype.get_scale(self.weight.dtype)
         bias_dtype = mgb.dtype.qint32(inp_scale * w_scale)
-        return F.linear(
+        ret = F.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
-        ).astype(self.output_dtype)
+        )
+        ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        return ret
 
     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):