Browse Source

feat(mge/quantization): make `q_dict` a kwarg rather than an arg

GitOrigin-RevId: 38e3b2bfaf
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
ab9fa48ee7
1 changed file with 11 additions and 11 deletions
  1. +11
    -11
      python_module/megengine/quantization/fake_quant.py

+ 11
- 11
python_module/megengine/quantization/fake_quant.py View File

@@ -50,17 +50,17 @@ class _FakeQuantize(Module):
def disable(self):
self.enabled = False

def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
return inp

def normal_foward(self, inp, q_dict):
def normal_foward(self, inp, q_dict=None):
return inp

def forward(self, inp, q_dict):
def forward(self, inp, q_dict=None):
if self.enabled:
return self.fake_quant_forward(inp, q_dict)
return self.fake_quant_forward(inp, q_dict=q_dict)
else:
return self.normal_foward(inp, q_dict)
return self.normal_foward(inp, q_dict=q_dict)


class TQT_Function(Function):
@@ -110,11 +110,11 @@ class TQT(_FakeQuantize):
super().__init__(dtype, narrow_range, enable)
self.scale = Parameter(0.0, dtype=np.float32)

def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
# when enable, TQT will do fakequant forward, finetune the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)

def normal_foward(self, inp, q_dict):
def normal_foward(self, inp, q_dict=None):
if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
@@ -123,9 +123,9 @@ class TQT(_FakeQuantize):
return inp

def get_qparams(self):
qdict = get_qparam_dict(QuantMode.TQT)
qdict["scale"] = 2 ** self.scale
return qdict
q_dict = get_qparam_dict(QuantMode.TQT)
q_dict["scale"] = 2 ** self.scale
return q_dict

def get_dtype(self):
q_dict = self.get_qparams()
@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize):
A module to do quant and dequant according to observer's scale and zero_point.
"""

def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)

Loading…
Cancel
Save