# Patch summary (leading text before the "diff --git" header, not part of any hunk):
#   * Gives q_dict a default of None on _FakeQuantize.fake_quant_forward,
#     _FakeQuantize.normal_foward, _FakeQuantize.forward,
#     TQT.fake_quant_forward, TQT.normal_foward and
#     FakeQuantize.fake_quant_forward.
#   * _FakeQuantize.forward now passes q_dict to the two hooks by keyword.
#   * TQT.get_qparams renames its local from qdict to q_dict.
#
# NOTE(review): TQT.normal_foward evaluates q_dict["enable_observer"] on its
#   first line, so calling it with the new default (None) raises TypeError.
#   Callers must still supply q_dict there, or a None guard should be added
#   in a follow-up change.
# NOTE(review): FakeQuantize.fake_quant_forward forwards q_dict (possibly
#   None) to fake_quant_tensor; whether that helper accepts None is not
#   visible in this patch -- confirm.
# NOTE(review): line breaks and indentation were reconstructed from a
#   whitespace-collapsed copy assuming 4-space indentation; verify with
#   `git apply --check` before applying.
diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py
index da365f16..7260b9db 100644
--- a/python_module/megengine/quantization/fake_quant.py
+++ b/python_module/megengine/quantization/fake_quant.py
@@ -50,17 +50,17 @@ class _FakeQuantize(Module):
     def disable(self):
         self.enabled = False
 
-    def fake_quant_forward(self, inp, q_dict):
+    def fake_quant_forward(self, inp, q_dict=None):
         return inp
 
-    def normal_foward(self, inp, q_dict):
+    def normal_foward(self, inp, q_dict=None):
         return inp
 
-    def forward(self, inp, q_dict):
+    def forward(self, inp, q_dict=None):
         if self.enabled:
-            return self.fake_quant_forward(inp, q_dict)
+            return self.fake_quant_forward(inp, q_dict=q_dict)
         else:
-            return self.normal_foward(inp, q_dict)
+            return self.normal_foward(inp, q_dict=q_dict)
 
 
 class TQT_Function(Function):
@@ -110,11 +110,11 @@ class TQT(_FakeQuantize):
         super().__init__(dtype, narrow_range, enable)
         self.scale = Parameter(0.0, dtype=np.float32)
 
-    def fake_quant_forward(self, inp, q_dict):
+    def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
         return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
 
-    def normal_foward(self, inp, q_dict):
+    def normal_foward(self, inp, q_dict=None):
         if q_dict["enable_observer"]:
             # when disable, TQT will do normal forward, initialize scale weight
             tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
@@ -123,9 +123,9 @@ class TQT(_FakeQuantize):
         return inp
 
     def get_qparams(self):
-        qdict = get_qparam_dict(QuantMode.TQT)
-        qdict["scale"] = 2 ** self.scale
-        return qdict
+        q_dict = get_qparam_dict(QuantMode.TQT)
+        q_dict["scale"] = 2 ** self.scale
+        return q_dict
 
     def get_dtype(self):
         q_dict = self.get_qparams()
@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize):
     A module to do quant and dequant according to observer's scale and zero_point.
     """
 
-    def fake_quant_forward(self, inp, q_dict):
+    def fake_quant_forward(self, inp, q_dict=None):
         return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)