Browse Source

fix(imperative/quantization): fix zero scale bug of easy quant

GitOrigin-RevId: f45e19b3e4
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
44bafd3f58
1 changed file with 11 additions and 6 deletions
  1. +11
    -6
      imperative/python/megengine/quantization/quantize.py

+ 11
- 6
imperative/python/megengine/quantization/quantize.py View File

@@ -13,6 +13,7 @@ import numpy as np


from .. import module as Float from .. import module as Float
from ..functional import concat, norm from ..functional import concat, norm
from ..logger import get_logger
from ..module import Module from ..module import Module
from ..module import qat as QAT from ..module import qat as QAT
from ..module import quantized as Quantized from ..module import quantized as Quantized
@@ -22,6 +23,8 @@ from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig from .qconfig import QConfig, ema_fakequant_qconfig


logger = get_logger(__name__)



def _get_quantable_module_names(): def _get_quantable_module_names():
def is_quantable(key: str): def is_quantable(key: str):
@@ -236,16 +239,18 @@ def apply_easy_quant(
return return


orig_scale = ob.orig_scale orig_scale = ob.orig_scale
distance = 0
best_scale = 0
cosine = optimal = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num): for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale ob.scale = scale
fakequant_out = mod(*fakequant_in) fakequant_out = mod(*fakequant_in)
dis = get_cosine(normal_out, fakequant_out) dis = get_cosine(normal_out, fakequant_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if dis > cosine:
cosine = dis
optimal = scale
if optimal == 0:
logger.warning("EasyQuant finds no better scale")
else:
ob.scale = optimal


fakequant_out = outputs[batch_size:] fakequant_out = outputs[batch_size:]
return concat([normal_out, fakequant_out]) return concat([normal_out, fakequant_out])


Loading…
Cancel
Save