GitOrigin-RevId: 17b28060cc
release-1.2
@@ -29,11 +29,14 @@ class Linear(QuantizedModule):
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_dtype = dtype.qint32(inp_scale * w_scale)
-        return F.nn.linear(
+        ret = F.nn.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
-        ).astype(self.output_dtype)
+        )
+        ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        return ret

     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):
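The hunk above makes Linear.forward skip the output cast when output_dtype is None instead of unconditionally calling astype(self.output_dtype). A minimal numpy sketch of that guarded-cast pattern (maybe_cast is an illustrative name, not a MegEngine API):

import numpy as np

def maybe_cast(ret, output_dtype=None):
    # Mirrors the new tail of Linear.forward: keep the accumulator dtype when
    # no output dtype is configured, otherwise cast to it.
    return ret if output_dtype is None else ret.astype(output_dtype)

acc = np.arange(6, dtype=np.int32).reshape(2, 3)   # stand-in for the qint32 accumulator
assert maybe_cast(acc).dtype == np.int32           # output_dtype=None: dtype preserved
assert maybe_cast(acc, np.float32).dtype == np.float32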
@@ -12,6 +12,7 @@ from .observer import HistogramObserver, Observer
 from .qconfig import (
     QConfig,
     calibration_qconfig,
+    easyquant_qconfig,
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
@@ -138,3 +138,5 @@ passive_qconfig = QConfig(
     weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )
+
+easyquant_qconfig = passive_qconfig
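The easyquant_qconfig added above is simply an alias of passive_qconfig, so its observers expect scales to be supplied rather than measured, and apply_easy_quant (modified below) then tunes those scales on calibration data. A hedged usage sketch, assuming apply_easy_quant is exposed from megengine.quantization.quantize alongside quantize_qat; net and calib_data are placeholders for a float module and a calibration batch:

from megengine.quantization import easyquant_qconfig
from megengine.quantization.quantize import apply_easy_quant, quantize_qat

qat_net = quantize_qat(net, qconfig=easyquant_qconfig)  # passive observers: scales are provided, not learned
apply_easy_quant(qat_net, calib_data)                   # grid-search each observer scale by cosine similarity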
@@ -223,11 +223,11 @@ def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
         mod._forward_hooks.clear()

-        fp32_in = [_[:batch_size] for _ in inputs]
-        int8_in = [_[batch_size:] for _ in inputs]
+        normal_in = [_[:batch_size] for _ in inputs]
+        fakequant_in = [_[batch_size:] for _ in inputs]

         disable_fake_quant(mod)
-        fp32_out = mod(*fp32_in)
+        normal_out = mod(*normal_in)
         enable_fake_quant(mod)

         ob = getattr(mod, where)
@@ -239,19 +239,15 @@ def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
         best_scale = 0

         for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
             ob.scale = scale
-            int8_out = mod(*int8_in)
-            dis = get_cosine(fp32_out, int8_out)
+            fakequant_out = mod(*fakequant_in)
+            dis = get_cosine(normal_out, fakequant_out)
             if dis > distance:
                 distance = dis
                 best_scale = scale
         ob.scale = best_scale

-        if where == "act_observer":
-            int8_out = mod(*int8_in)
-            return concat([fp32_out, int8_out])
-        else:
-            int8_out = outputs[batch_size:]
-            return concat([fp32_out, int8_out])
+        fakequant_out = outputs[batch_size:]
+        return concat([normal_out, fakequant_out])

     data = concat([data, data])
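The loop above is the core of the EasyQuant-style tuning: for each observer it scans num candidate scales in [start, stop] times the original scale and keeps the one whose fake-quantized output is most cosine-similar to the output computed with fake quant disabled. A self-contained numpy sketch of that search, simplified to quantizing a single tensor rather than re-running a module (fake_quant_np, get_cosine_np and search_scale are illustrative stand-ins for the module's helpers):

import numpy as np

def fake_quant_np(x, scale):
    # int8 simulation: scale, round, clamp, rescale
    return np.clip(np.round(x / scale), -128, 127) * scale

def get_cosine_np(a, b):
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

def search_scale(x, orig_scale, start=0.8, stop=1.2, num=40):
    best_scale, best_sim = orig_scale, -1.0
    for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
        sim = get_cosine_np(x, fake_quant_np(x, scale))
        if sim > best_sim:
            best_sim, best_scale = sim, scale
    return best_scale

x = np.random.normal(size=(64,)).astype("float32")
best = search_scale(x, orig_scale=np.abs(x).max() / 127)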
@@ -0,0 +1,203 @@
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as Float
import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_observer, propagate_qconfig

"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
"""

inp_scale = np.float32(np.random.rand() + 1)

min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
weight_scale = np.float32(np.max([-min_val[0], max_val[0]]) / 254 * 2)
act_scale = np.float32(np.max([-min_val[1], max_val[1]]) / 255 * 2)


def quant(x, scale):
    inp_dtype = dtype.qint8(scale)
    return x.astype(inp_dtype)


def fake_quant(x, scale):
    x = x / scale
    x = F.round(x)
    x = F.clip(x, -128, 127)
    x = x * scale
    return x


def init_qat_net(net):
    if net.with_weight:
        net.weight_observer.min_val.set_value(min_val[0])
        net.weight_observer.max_val.set_value(max_val[0])
    if net.with_act:
        net.act_observer.min_val.set_value(min_val[1])
        net.act_observer.max_val.set_value(max_val[1])


def test_quant_stub():
    normal_net = Float.QuantStub()
    normal_net.eval()

    qat_net = QAT.QuantStub()
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.QuantStub.from_qat_module(qat_net)
    q_net.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))

    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


def test_dequant_stub():
    normal_net = Float.DequantStub()
    normal_net.eval()

    qat_net = QAT.DequantStub()
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.DequantStub.from_qat_module(qat_net)
    q_net.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale

    normal_out = normal_net(x)
    qat_out = qat_net(x)
    q_out = q_net(quant(x, inp_scale)).numpy()
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
def test_elemwise(kind):
    normal_net = Float.Elemwise(kind)
    normal_net.eval()

    qat_net = QAT.Elemwise(kind)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.Elemwise.from_qat_module(qat_net)
    q_net.eval()

    x1_scale = np.float32(np.random.rand() + 1)
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant(x1, x1_scale)
    x1.q_dict["scale"] = x1_scale

    x2_scale = np.float32(np.random.rand() + 1)
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant(x2, x2_scale)
    x2.q_dict["scale"] = x2_scale

    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)

    if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
        normal_out = fake_quant(normal_net(x1, x2), act_scale)
        qat_out = qat_net(x1, x2)
        q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
    else:
        normal_out = fake_quant(normal_net(x1), act_scale)
        qat_out = qat_net(x1)
        q_out = q_net(x1_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


def test_linear():
    normal_net = Float.Linear(3, 3, bias=True)
    normal_net.eval()

    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale

    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")
    normal_net.weight.set_value(fake_quant(weight, weight_scale))
    normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
    qat_net.weight.set_value(weight)
    qat_net.bias.set_value(bias)

    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()

    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())


@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
    normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
    normal_net.eval()

    qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
    qat_net.eval()
    disable_observer(qat_net)

    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale

    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
    if module in ("ConvBn2d", "ConvBnRelu2d"):
        normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
        normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
        qat_net.conv.weight.set_value(weight)
        qat_net.conv.bias.set_value(bias)
    else:
        normal_net.weight.set_value(fake_quant(weight, weight_scale))
        normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
        qat_net.weight.set_value(weight)
        qat_net.bias.set_value(bias)

    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()

    normal_out = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    q_out = q_net(x_int8).numpy() * act_scale
    np.testing.assert_allclose(qat_out, normal_out)
    np.testing.assert_allclose(q_out, normal_out.numpy())
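The new test file above compares three paths for each module (a float reference fed fake-quantized inputs and weights, its QAT counterpart, and the converted quantized module whose int8 output is rescaled by act_scale) and relies on dequantized int8 values matching the fake-quant simulation exactly. A standalone numpy check of that identity (fake_quant_np is my plain-numpy rendering of the fake_quant helper above):

import numpy as np

def fake_quant_np(x, scale):
    return np.clip(np.round(x / scale), -128, 127) * scale

scale = 0.05
x = np.random.normal(size=(3, 3)).astype("float32")
q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)  # what the quantized path carries
np.testing.assert_allclose(q.astype("float32") * scale, fake_quant_np(x, scale))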
@@ -103,7 +103,7 @@ def test_sync_exponential_moving_average_observer():
         y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
         m(y1)
         m(y2)
-        np.testing.assert_allclose(m.min_val.numpy(), expected_min)
-        np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+        np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-6)
+        np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-6)

     worker()
@@ -0,0 +1,161 @@
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.elemwise import _elemwise_multi_type, _elwise


def quant(x, scale):
    x_dtype = dtype.qint8(scale)
    return x.astype(x_dtype)


def fake_quant(x, scale):
    x = x / scale
    x = F.round(x)
    x = F.clip(x, -128, 127)
    x = x * scale
    return x


@pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"])
def test_elemwise(kind):
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = fake_quant(x1, x1_scale)
    x1.q_dict["scale"] = x1_scale
    x1_int8 = quant(x1, x1_scale)

    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = fake_quant(x2, x2_scale)
    x2.q_dict["scale"] = x2_scale
    x2_int8 = quant(x2, x2_scale)

    output_scale = np.float32(np.random.rand() + 1)
    output_dtype = dtype.qint8(output_scale)

    quantized_kind = "Q" + kind
    if kind in ("ABS", "SIN"):
        desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
        actual_out = (
            _elemwise_multi_type(
                x1_int8, mode=quantized_kind, dtype=output_dtype
            ).numpy()
            * output_scale
        )
    else:
        desired_out = fake_quant(_elwise(x1, x2, mode=kind), output_scale)
        actual_out = (
            _elemwise_multi_type(
                x1_int8, x2_int8, mode=quantized_kind, dtype=output_dtype
            ).numpy()
            * output_scale
        )
    np.testing.assert_allclose(actual_out, desired_out.numpy())


@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
    inp_scale = np.float32(np.random.rand() + 1)
    w_scale = np.float32(np.random.rand() + 1)
    outp_scale = np.float32(np.random.rand() + 1)
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="IDENTITY",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = mge.tensor(inpv, dtype=inp_dtype)
        w_int8 = mge.Parameter(wv, dtype=w_dtype)
        b_int32 = mge.Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "RELU":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if mge.is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
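The convert_to_nchw4 helper above packs channels into groups of four for the int8 NCHW4 layout used on the CUDA path. The same layout change in plain numpy (to_nchw4 is an illustrative stand-in, not a MegEngine API):

import numpy as np

def to_nchw4(x):
    # (N, C, H, W) -> (N, C//4, H, W, 4): split channels into groups of 4 and
    # move the size-4 group axis to the end, as convert_to_nchw4 does with
    # F.reshape and F.transpose.
    n, c, h, w = x.shape
    return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

x = np.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
assert to_nchw4(x).shape == (2, 2, 3, 3, 4)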
@@ -88,12 +88,14 @@ def test_propagate_qconfig():
 def init_qat_net():
     net = QATNet()
     propagate_qconfig(net, min_max_fakequant_qconfig)
-    min_val = np.random.randint(-127, 0, size=(2,))
-    max_val = np.random.randint(1, 127, size=(2,))
-    net.linear.weight_observer.min_val.set_value(min_val[0])
-    net.linear.weight_observer.max_val.set_value(max_val[0])
-    net.linear.act_observer.min_val.set_value(min_val[1])
-    net.linear.act_observer.max_val.set_value(max_val[1])
+    min_val = np.random.randint(-127, 0, size=(3,))
+    max_val = np.random.randint(1, 127, size=(3,))
+    net.quant.act_observer.min_val.set_value(min_val[0])
+    net.quant.act_observer.max_val.set_value(max_val[0])
+    net.linear.weight_observer.min_val.set_value(min_val[1])
+    net.linear.weight_observer.max_val.set_value(max_val[1])
+    net.linear.act_observer.min_val.set_value(min_val[2])
+    net.linear.act_observer.max_val.set_value(max_val[2])
     return net