GitOrigin-RevId: 383458acbf
release-1.1
@@ -62,6 +62,21 @@ def get_zero_point(dtype): | |||||
return metadata["zero_point"] | return metadata["zero_point"] | ||||
def is_equal(dt0, dt1): | |||||
def _get_zero_point(dtype): | |||||
assert is_quantize(dtype) | |||||
metadata = dtype.metadata["mgb_dtype"] | |||||
return metadata.get("zero_point") | |||||
if is_quantize(dt0) and is_quantize(dt1): | |||||
return get_scale(dt0) == get_scale(dt1) and _get_zero_point( | |||||
dt0 | |||||
) == _get_zero_point(dt1) | |||||
if not (is_quantize(dt0) or is_quantize(dt1)): | |||||
return dt0 == dt1 | |||||
return False | |||||
def _check_zero_point(zp: int, dtype_str: str): | def _check_zero_point(zp: int, dtype_str: str): | ||||
qmin = _metadata_dict[dtype_str].qmin | qmin = _metadata_dict[dtype_str].qmin | ||||
qmax = _metadata_dict[dtype_str].qmax | qmax = _metadata_dict[dtype_str].qmax | ||||
@@ -14,6 +14,7 @@ import numpy as np | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | ||||
from .dtype import is_equal | |||||
def dtype_promotion(inputs): | def dtype_promotion(inputs): | ||||
@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
def astype(x, dtype): | def astype(x, dtype): | ||||
dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
if x.dtype != dtype: | |||||
if not is_equal(x.dtype, dtype): | |||||
(x,) = apply(builtin.TypeCvt(param=dtype), x) | (x,) = apply(builtin.TypeCvt(param=dtype), x) | ||||
return x | return x | ||||
@@ -8,6 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||||
from megengine.core.tensor.tensor_wrapper import TensorWrapper | from megengine.core.tensor.tensor_wrapper import TensorWrapper | ||||
@@ -71,3 +72,17 @@ def test_transpose(): | |||||
x = np.random.rand(2, 5).astype("float32") | x = np.random.rand(2, 5).astype("float32") | ||||
xx = TensorWrapper(x) | xx = TensorWrapper(x) | ||||
np.testing.assert_almost_equal(xx.T.numpy(), x.T) | np.testing.assert_almost_equal(xx.T.numpy(), x.T) | ||||
def test_as_type(): | |||||
x = TensorWrapper([1, 2, 3], dtype=np.float32) | |||||
y = x.astype(qint8(0.1)) | |||||
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) | |||||
z = y.astype(qint8(0.2)) | |||||
np.testing.assert_almost_equal(get_scale(z.dtype), 0.2) | |||||
a = z.astype(quint8(0.3, 127)) | |||||
np.testing.assert_almost_equal(get_scale(a.dtype), 0.3) | |||||
np.testing.assert_equal(get_zero_point(a.dtype), 127) | |||||
b = a.astype(quint8(0.3, 128)) | |||||
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) | |||||
np.testing.assert_equal(get_zero_point(b.dtype), 128) |