|
|
@@ -13,6 +13,7 @@ import numpy as np |
|
|
|
|
|
|
|
from .. import module as Float |
|
|
|
from ..functional import concat, norm |
|
|
|
from ..logger import get_logger |
|
|
|
from ..module import Module |
|
|
|
from ..module import qat as QAT |
|
|
|
from ..module import quantized as Quantized |
|
|
@@ -22,6 +23,8 @@ from ..tensor import Tensor |
|
|
|
from ..utils.module_utils import set_expand_structure |
|
|
|
from .qconfig import QConfig, ema_fakequant_qconfig |
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def _get_quantable_module_names(): |
|
|
|
def is_quantable(key: str): |
|
|
@@ -236,16 +239,18 @@ def apply_easy_quant( |
|
|
|
return |
|
|
|
|
|
|
|
orig_scale = ob.orig_scale |
|
|
|
distance = 0 |
|
|
|
best_scale = 0 |
|
|
|
cosine = optimal = 0 |
|
|
|
for scale in np.linspace(start * orig_scale, stop * orig_scale, num): |
|
|
|
ob.scale = scale |
|
|
|
fakequant_out = mod(*fakequant_in) |
|
|
|
dis = get_cosine(normal_out, fakequant_out) |
|
|
|
if dis > distance: |
|
|
|
distance = dis |
|
|
|
best_scale = scale |
|
|
|
ob.scale = best_scale |
|
|
|
if dis > cosine: |
|
|
|
cosine = dis |
|
|
|
optimal = scale |
|
|
|
if optimal == 0: |
|
|
|
logger.warning("EasyQuant finds no better scale") |
|
|
|
else: |
|
|
|
ob.scale = optimal |
|
|
|
|
|
|
|
fakequant_out = outputs[batch_size:] |
|
|
|
return concat([normal_out, fakequant_out]) |
|
|
|