GitOrigin-RevId: 1c5195176d
release-1.7
@@ -85,54 +85,6 @@ def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
    np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype))

def test_dtype_int8_ffi_handle():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    def identity(x):
        return x

    dtype = quint8(0.01, 127)
    inp = convert_to_quint8(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "quint8")
    np.testing.assert_allclose(convert_from_quint8(oup), convert_from_quint8(inp))

    dtype = qint8(0.01)
    inp = convert_to_qint8(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "qint8", is_unsigned=False)
    np.testing.assert_allclose(convert_from_qint8(oup), convert_from_qint8(inp))

def test_quint8_typecvt():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    def typecvt(x, dt=None):
        (y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x)
        return y

    # convert to quint8
    dtype = quint8(0.01, 135)
    oup = _get_compiled_result(
        data, np.float32, shape, device, calc_func=partial(typecvt, dt=dtype)
    )
    _check_result_attr(oup, dtype, "quint8")
    np.testing.assert_equal(oup, convert_to_quint8(data, dtype))

    # convert from quint8 to float32
    oup_float = _get_compiled_result(
        oup, dtype, shape, device, calc_func=partial(typecvt, dt=np.float32)
    )
    assert oup_float.dtype == np.float32
    np.testing.assert_equal(
        oup_float, convert_from_quint8(convert_to_quint8(data, dtype))
    )

def test_dtype_quint4():
    with pytest.raises(ValueError):
        blah = quint4(0.05, 0.233)
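    # Illustrative note (not part of the patch): the zero_point of a quint4 is
    # presumably required to be an integer in [0, 15], so passing the float
    # 0.233 above is what triggers the ValueError.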
@@ -161,31 +113,43 @@ def test_dtype_qint4():
    np.testing.assert_allclose(get_scale(dt), 0.01)

def test_dtype_int4_ffi_handle():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

@pytest.mark.parametrize(
    "dtype, dtype_name",
    [
        (quint4(0.01, 5), "quint4"),
        (qint4(0.01), "qint4"),
        (quint8(0.01, 135), "quint8"),
        (qint8(0.01), "qint8"),
    ],
)
def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name):
    def identity(x):
        return x

    dtype = quint4(0.01, 7)
    inp = convert_to_quint4(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "quint4")
    np.testing.assert_allclose(convert_from_quint4(oup), convert_from_quint4(inp))

    convert_to_dtype = eval("convert_to_%s" % dtype_name)
    convert_from_dtype = eval("convert_from_%s" % dtype_name)
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    dtype = qint4(0.01)
    inp = convert_to_qint4(data, dtype)
    inp = convert_to_dtype(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "qint4", is_unsigned=False)
    np.testing.assert_allclose(convert_from_qint4(oup), convert_from_qint4(inp))

@pytest.mark.skipif(
    get_device_count("gpu") != 0, reason="TypeCvt to quint4 is not supported on GPU",
    _check_result_attr(oup, dtype, dtype_name, dtype_name.startswith("qu"))
    np.testing.assert_allclose(convert_from_dtype(oup), convert_from_dtype(inp))

@pytest.mark.parametrize(
    "dtype, dtype_name",
    [
        (quint4(0.01, 5), "quint4"),
        (qint4(0.01), "qint4"),
        (quint8(0.01, 135), "quint8"),
        (qint8(0.01), "qint8"),
    ],
)
def test_quint4_typecvt():
def test_qint_typecvt(dtype, dtype_name):
    convert_to_dtype = eval("convert_to_%s" % dtype_name)
    convert_from_dtype = eval("convert_from_%s" % dtype_name)
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1
@@ -195,12 +159,11 @@ def test_quint4_typecvt():
        return y

    # convert to quint4
    dtype = quint4(0.01, 5)
    oup = _get_compiled_result(
        data, np.float32, shape, device, calc_func=partial(typecvt, dt=dtype)
    )
    _check_result_attr(oup, dtype, "quint4")
    np.testing.assert_equal(oup, convert_to_quint4(data, dtype))
    _check_result_attr(oup, dtype, dtype_name, dtype_name.startswith("qu"))
    np.testing.assert_equal(oup, convert_to_dtype(data, dtype))

    # convert from quint4 to float32
    oup_float = _get_compiled_result(
@@ -208,5 +171,62 @@ def test_quint4_typecvt():
    )
    assert oup_float.dtype == np.float32
    np.testing.assert_equal(
        oup_float, convert_from_quint4(convert_to_quint4(data, dtype))
        oup_float, convert_from_dtype(convert_to_dtype(data, dtype))
    )

@pytest.mark.parametrize(
    "dtype, dtype_name",
    [
        (quint4(0.01, 5), "quint4"),
        (qint4(0.01), "qint4"),
        (quint8(0.01, 135), "quint8"),
        (qint8(0.01), "qint8"),
    ],
)
def test_qint_astype(dtype, dtype_name):
    convert_to_dtype = eval("convert_to_%s" % dtype_name)
    convert_from_dtype = eval("convert_from_%s" % dtype_name)
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1
    inp = Tensor(data, dtype="float32")

    # convert to quint4
    oup = inp.astype(dtype)
    _check_result_attr(oup, dtype, dtype_name, dtype_name.startswith("qu"))
    np.testing.assert_equal(oup.numpy(), convert_to_dtype(data, dtype))

    # convert from quint4 to float32
    oup_float = oup.astype("float32")
    assert oup_float.dtype == np.float32
    np.testing.assert_equal(
        oup_float.numpy(), convert_from_dtype(convert_to_dtype(data, dtype))
    )

@pytest.mark.parametrize(
    "dtype, dtype_name",
    [
        (quint4(0.01, 5), "quint4"),
        (qint4(0.01), "qint4"),
        (quint8(0.01, 135), "quint8"),
        (qint8(0.01), "qint8"),
    ],
)
def test_qint_new_tensor(dtype, dtype_name):
    convert_to_dtype = eval("convert_to_%s" % dtype_name)
    convert_from_dtype = eval("convert_from_%s" % dtype_name)
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    # create a new Tensor with quint8 dtype
    inp = Tensor(convert_to_dtype(data, dtype), dtype=dtype)
    _check_result_attr(inp, dtype, dtype_name, dtype_name.startswith("qu"))
    np.testing.assert_equal(inp.numpy(), convert_to_dtype(data, dtype))

    # convert from quint8 to float32
    inp_float = inp.astype("float32")
    assert inp_float.dtype == np.float32
    np.testing.assert_equal(
        inp_float.numpy(), convert_from_dtype(convert_to_dtype(data, dtype))
    )
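
For reference, a minimal self-contained sketch of the quantize/dequantize round trip exercised by the parametrized tests above, assuming the helpers exported by megengine.core.tensor.dtype (quint8, convert_to_quint8, convert_from_quint8) and Tensor.astype behave exactly as they are used in those tests:

import numpy as np
from megengine import Tensor
from megengine.core.tensor.dtype import convert_from_quint8, convert_to_quint8, quint8

data = np.random.random((3, 3, 3)).astype(np.float32) * 5 - 1
dt = quint8(0.01, 127)                             # scale=0.01, zero_point=127
q = Tensor(convert_to_quint8(data, dt), dtype=dt)  # quantized tensor
back = q.astype("float32").numpy()                 # dequantize via astype
# Quantization is lossy, so `back` matches the numpy reference conversion
# rather than the original float data.
np.testing.assert_equal(back, convert_from_quint8(convert_to_quint8(data, dt)))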
@@ -12,7 +12,7 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.device import get_device_count
from megengine.device import get_cuda_compute_capability, get_device_count
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
from megengine.module.quantized.conv import ConvTranspose2d
from megengine.quantization import QuantMode, create_qparams
@@ -171,8 +171,92 @@ def test_conv_bias():
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")

@pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
def test_conv_bias_int4():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.quint4(inp_scale, 0)
    w_dtype = dtype.qint4(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.quint4(outp_scale, 0)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="identity",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_quint4(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint4(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_uint4 = mge.Tensor(inpv, dtype=inp_dtype)
        w_int4 = mge.Parameter(wv, dtype=w_dtype)
        b_int32 = mge.Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_uint4.astype("float32")
        w_fp32 = w_int4.astype("float32")
        b_fp32 = b_int32.astype("float32")
        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "relu":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b):
            b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_uint4, w_int4, b_int32).astype("float32")
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
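
A brief sketch of why atol=outp_scale is a sensible bound, assuming convert_to_quint4 rounds value / scale to the nearest integer in [0, 15]: both the float32 reference path and the quantized kernel quantize their output with step outp_scale, rounding can move each value by up to half a step, and the two paths round independently, so their dequantized results may differ by roughly one full step.

import numpy as np

scale, zero_point = 1.5, 0  # matches outp_scale / zero_point used above
x = np.float32(4.9)
q = np.clip(np.round(x / scale) + zero_point, 0, 15)  # -> 3
dequant = (q - zero_point) * scale                    # -> 4.5
assert abs(x - dequant) <= scale / 2                  # rounding error <= half a step
# Two independently rounded paths can therefore disagree by up to about one
# full step, which is why the comparison uses atol=outp_scale.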

@pytest.mark.skipif(
    get_device_count("gpu") > 0,
    get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_conv_transpose2d():
@@ -13,7 +13,11 @@ import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply
from megengine.core._wrap import Device
from megengine.core.ops import builtin
from megengine.device import get_device_count, is_cuda_available
from megengine.device import (
    get_cuda_compute_capability,
    get_device_count,
    is_cuda_available,
)
from megengine.functional.debug_param import (
    get_execution_strategy,
    set_execution_strategy,
@@ -287,7 +291,7 @@ def test_deformable_ps_roi_pooling():
@pytest.mark.skipif(
    get_device_count("gpu") > 0,
    get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_convbias():
@@ -304,6 +308,27 @@ def test_convbias():
    check_pygraph_dump(fwd, [inp, weight, bias], [result])

@pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
def test_conv_bias_int4():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, weight, bias):
        return F.quantized.conv_bias_activation(
            inp,
            weight,
            bias,
            dtype=dtype.quint4(scale=1.0, zero_point=0),
            nonlinear_mode="relu",
        )

    inp = Tensor(
        np.random.random((1, 3, 64, 64)), dtype=dtype.quint4(scale=1.0, zero_point=0)
    )
    weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint4(scale=1.0))
    bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
    result = fwd(inp, weight, bias)
    check_pygraph_dump(fwd, [inp, weight, bias], [result])

def test_batch_convbias():
    if is_cuda_available():
        return
@@ -616,6 +616,7 @@ public:
        descs.emplace_back();
        auto& desc = descs.back();
        desc.layout.dtype = minigraph.output_var(i)->dtype();
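        // also record the inferred output's tensor format in the descriptor,
        // not just its dtype and comp_node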
        desc.layout.format = minigraph.output_var(i)->format();
        desc.comp_node = minigraph.output_var(i)->comp_node();
        if (auto* shape = sess.infer_shape(i, false)) {
            desc.layout.init_contiguous_stride(*shape);
@@ -435,6 +435,11 @@ DEF(DType dtype) : m_layout{dtype} {}
DEF(CompNode node, DType dtype) : m_storage{node}, m_layout{dtype} {}

//! allocate contiguous from given comp node, shape and dtype
DEF(CompNode node, const TensorShape& shape, DType dtype)
        : m_storage{node}, m_layout{dtype} {
    resize(shape);
}

DEF(CompNode node, const TensorShape& shape, DType dtype, TensorFormat format)
        : m_storage{node}, m_layout{dtype, format} {
    resize(shape);
@@ -339,8 +339,10 @@ public:
    //! allocate contiguous tensor
    MGE_WIN_DECLSPEC_FUC TensorND(
            CompNode node, const TensorShape& shape, DType dtype = dtype::Float32{},
            TensorFormat format = {});
            CompNode node, const TensorShape& shape, DType dtype = dtype::Float32{});

    MGE_WIN_DECLSPEC_FUC TensorND(
            CompNode node, const TensorShape& shape, DType dtype, TensorFormat format);

    //! allocate contiguous tensor from given comp node and layout; layout
    //! is required to be contiguous, and its dtype and format would be used