
fix(mge/quantization): fix quantized concat forward problem

GitOrigin-RevId: dc21b340d1
release-1.3
Megvii Engine Team 4 years ago
parent commit 0ed3699895
3 changed files with 56 additions and 4 deletions
  1. imperative/python/megengine/module/quantized/concat.py (+1, -1)
  2. imperative/python/megengine/quantization/__init__.py (+10, -2)
  3. imperative/python/test/unit/quantization/test_module.py (+45, -1)

imperative/python/megengine/module/quantized/concat.py (+1, -1)

@@ -23,7 +23,7 @@ class Concat(QuantizedModule):
         self.output_dtype = dtype
 
     def forward(self, inps: Iterable[Tensor], axis: int = 0):
-        new_inps = (x.astype(self.output_dtype) for x in inps)
+        new_inps = tuple(x.astype(self.output_dtype) for x in inps)
         return F.concat(new_inps, axis)
 
     @classmethod
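
For context, a minimal sketch (plain Python, not MegEngine internals) of why a one-shot generator expression can break an op that reads its inputs more than once, and why materializing a tuple fixes it; concat_like below is a hypothetical stand-in for an op like F.concat:

# Hypothetical stand-in for an op that may walk its inputs more than
# once (e.g. once to validate them, once to copy their data).
def concat_like(inps):
    count = sum(1 for _ in inps)   # first pass consumes a generator
    data = list(inps)              # second pass: generator already empty
    return count, data

gen = (x + 1 for x in [1, 2, 3])
print(concat_like(gen))                # (3, [])        -- inputs silently lost
tup = tuple(x + 1 for x in [1, 2, 3])
print(concat_like(tup))                # (3, [2, 3, 4]) -- tuple survives re-reads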


imperative/python/megengine/quantization/__init__.py (+10, -2)

@@ -6,8 +6,16 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
-from .fake_quant import FakeQuantize
-from .observer import Observer
+from .fake_quant import TQT, FakeQuantize
+from .observer import (
+    ExponentialMovingAverageObserver,
+    HistogramObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
 from .qconfig import (
     QConfig,
     calibration_qconfig,
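
As a usage note, the re-exports above mean these names now resolve at the package top level; a sketch, assuming a MegEngine build that includes this commit:

from megengine.quantization import (
    TQT,
    FakeQuantize,
    Observer,
    MinMaxObserver,
    HistogramObserver,
    PassiveObserver,
)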


imperative/python/test/unit/quantization/test_module.py (+45, -1)

@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8"),
 )
 
-inp_scale = np.float32(np.random.rand() + 1)
+
+def gen_inp_scale():
+    return np.float32(np.random.rand() + 1)
+
 
 min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
 max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
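
The assumed motivation (not stated in the diff) is that a module-level inp_scale is drawn once at import time and silently shared across all tests, while a helper returns a fresh scale per call — which test_concat below relies on to give each of its three inputs an independent scale. A minimal illustration:

import numpy as np

shared_scale = np.float32(np.random.rand() + 1)  # one draw at import time

def gen_inp_scale():
    return np.float32(np.random.rand() + 1)      # fresh draw per call

print(shared_scale, shared_scale)                # identical values
print(gen_inp_scale(), gen_inp_scale())          # independent draws in [1, 2)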
@@ -116,6 +119,7 @@ def test_dequant_stub():
     q_net.eval()
 
     x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
     x = fake_quant_act(x, inp_scale)
     x.qparams.scale = inp_scale
 
@@ -192,6 +196,7 @@ def test_linear():
     init_qat_net(qat_net)
 
     x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
     x = fake_quant_act(x, inp_scale)
     x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
 
@@ -235,6 +240,7 @@ def test_conv(module):
     init_qat_net(qat_net)
 
     x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
+    inp_scale = gen_inp_scale()
     x = fake_quant_act(x, inp_scale)
     x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
 
@@ -269,3 +275,41 @@ def test_conv(module):
     np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
     np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
     np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)
+
+
+def test_concat():
+    normal_net = Float.Concat()
+    normal_net.eval()
+
+    qat_net = QAT.Concat()
+    qat_net.eval()
+    disable_observer(qat_net)
+
+    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
+    init_qat_net(qat_net)
+
+    inps = []
+    inps_int8 = []
+    for i in range(3):
+        inp_scale = gen_inp_scale()
+        inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
+        inps[i] = fake_quant_act(inps[i], inp_scale)
+        inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
+        inps_int8.append(quant(inps[i], inp_scale))
+
+    qat_from_float = QAT.Concat.from_float_module(normal_net)
+    qat_from_float.eval()
+    disable_fake_quant(qat_from_float)
+    disable_observer(qat_from_float)
+
+    q_net = Q.Concat.from_qat_module(qat_net)
+    q_net.eval()
+
+    normal = normal_net(inps)
+    qat_without_fakequant = qat_from_float(inps)
+    fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
+    qat = qat_net(inps)
+    q = q_net(inps_int8).numpy() * act_scale
+    np.testing.assert_allclose(qat_without_fakequant, normal)
+    np.testing.assert_allclose(qat, fake_quant_normal.numpy())
+    np.testing.assert_allclose(q, fake_quant_normal.numpy())
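
A worked sketch of the final assertion, under the assumption that fake quantization computes round(x / s) * s (clipping omitted) while the quantized net emits raw quantized values at scale act_scale, so dequantizing by multiplication should reproduce the fake-quantized float output:

import numpy as np

s = np.float32(0.5)                    # hypothetical act_scale
x = np.array([0.3, -1.2, 0.9], np.float32)
q_out = np.round(x / s)                # int8 path: raw quantized values
fake_quant = np.round(x / s) * s       # float path with fake quantization
np.testing.assert_allclose(q_out * s, fake_quant)  # mirrors q * act_scale == fake_quant_normal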
