Browse Source

fix(imperative/quantization): fix zero scale bug of easy quant

GitOrigin-RevId: f45e19b3e4
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
44bafd3f58
1 changed file with 11 additions and 6 deletions
  1. +11
    -6
      imperative/python/megengine/quantization/quantize.py

+ 11
- 6
imperative/python/megengine/quantization/quantize.py View File

@@ -13,6 +13,7 @@ import numpy as np

from .. import module as Float
from ..functional import concat, norm
from ..logger import get_logger
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
@@ -22,6 +23,8 @@ from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig

logger = get_logger(__name__)


def _get_quantable_module_names():
def is_quantable(key: str):
@@ -236,16 +239,18 @@ def apply_easy_quant(
return

orig_scale = ob.orig_scale
distance = 0
best_scale = 0
cosine = optimal = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale
fakequant_out = mod(*fakequant_in)
dis = get_cosine(normal_out, fakequant_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if dis > cosine:
cosine = dis
optimal = scale
if optimal == 0:
logger.warning("EasyQuant finds no better scale")
else:
ob.scale = optimal

fakequant_out = outputs[batch_size:]
return concat([normal_out, fakequant_out])


Loading…
Cancel
Save