|
|
@@ -50,17 +50,17 @@ class _FakeQuantize(Module): |
|
|
|
def disable(self): |
|
|
|
self.enabled = False |
|
|
|
|
|
|
|
def fake_quant_forward(self, inp, q_dict): |
|
|
|
def fake_quant_forward(self, inp, q_dict=None): |
|
|
|
return inp |
|
|
|
|
|
|
|
def normal_foward(self, inp, q_dict): |
|
|
|
def normal_foward(self, inp, q_dict=None): |
|
|
|
return inp |
|
|
|
|
|
|
|
def forward(self, inp, q_dict): |
|
|
|
def forward(self, inp, q_dict=None): |
|
|
|
if self.enabled: |
|
|
|
return self.fake_quant_forward(inp, q_dict) |
|
|
|
return self.fake_quant_forward(inp, q_dict=q_dict) |
|
|
|
else: |
|
|
|
return self.normal_foward(inp, q_dict) |
|
|
|
return self.normal_foward(inp, q_dict=q_dict) |
|
|
|
|
|
|
|
|
|
|
|
class TQT_Function(Function): |
|
|
@@ -110,11 +110,11 @@ class TQT(_FakeQuantize): |
|
|
|
super().__init__(dtype, narrow_range, enable) |
|
|
|
self.scale = Parameter(0.0, dtype=np.float32) |
|
|
|
|
|
|
|
def fake_quant_forward(self, inp, q_dict): |
|
|
|
def fake_quant_forward(self, inp, q_dict=None): |
|
|
|
# when enable, TQT will do fakequant forward, finetune the scale |
|
|
|
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) |
|
|
|
|
|
|
|
def normal_foward(self, inp, q_dict): |
|
|
|
def normal_foward(self, inp, q_dict=None): |
|
|
|
if q_dict["enable_observer"]: |
|
|
|
# when disable, TQT will do normal forward, initialize scale weight |
|
|
|
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) |
|
|
@@ -123,9 +123,9 @@ class TQT(_FakeQuantize): |
|
|
|
return inp |
|
|
|
|
|
|
|
def get_qparams(self): |
|
|
|
qdict = get_qparam_dict(QuantMode.TQT) |
|
|
|
qdict["scale"] = 2 ** self.scale |
|
|
|
return qdict |
|
|
|
q_dict = get_qparam_dict(QuantMode.TQT) |
|
|
|
q_dict["scale"] = 2 ** self.scale |
|
|
|
return q_dict |
|
|
|
|
|
|
|
def get_dtype(self): |
|
|
|
q_dict = self.get_qparams() |
|
|
@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize): |
|
|
|
A module to do quant and dequant according to observer's scale and zero_point. |
|
|
|
""" |
|
|
|
|
|
|
|
def fake_quant_forward(self, inp, q_dict): |
|
|
|
def fake_quant_forward(self, inp, q_dict=None): |
|
|
|
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict) |