Browse Source

fix(mge/imperative): fix tensor astype failed for quantized type

GitOrigin-RevId: 383458acbf
release-1.1
Megvii Engine Team 4 years ago
parent
commit
be511a56f9
3 changed files with 32 additions and 1 deletions
  1. +15
    -0
      imperative/python/megengine/core/tensor/dtype.py
  2. +2
    -1
      imperative/python/megengine/core/tensor/utils.py
  3. +15
    -0
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 15
- 0
imperative/python/megengine/core/tensor/dtype.py View File

@@ -62,6 +62,21 @@ def get_zero_point(dtype):
return metadata["zero_point"] return metadata["zero_point"]




def is_equal(dt0, dt1):
def _get_zero_point(dtype):
assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"]
return metadata.get("zero_point")

if is_quantize(dt0) and is_quantize(dt1):
return get_scale(dt0) == get_scale(dt1) and _get_zero_point(
dt0
) == _get_zero_point(dt1)
if not (is_quantize(dt0) or is_quantize(dt1)):
return dt0 == dt1
return False


def _check_zero_point(zp: int, dtype_str: str): def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax qmax = _metadata_dict[dtype_str].qmax


+ 2
- 1
imperative/python/megengine/core/tensor/utils.py View File

@@ -14,6 +14,7 @@ import numpy as np
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .dtype import is_equal




def dtype_promotion(inputs): def dtype_promotion(inputs):
@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None):


def astype(x, dtype): def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if x.dtype != dtype:
if not is_equal(x.dtype, dtype):
(x,) = apply(builtin.TypeCvt(param=dtype), x) (x,) = apply(builtin.TypeCvt(param=dtype), x)
return x return x




+ 15
- 0
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np


from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.core.tensor.tensor_wrapper import TensorWrapper from megengine.core.tensor.tensor_wrapper import TensorWrapper




@@ -71,3 +72,17 @@ def test_transpose():
x = np.random.rand(2, 5).astype("float32") x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x) xx = TensorWrapper(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T) np.testing.assert_almost_equal(xx.T.numpy(), x.T)


def test_as_type():
x = TensorWrapper([1, 2, 3], dtype=np.float32)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))
np.testing.assert_almost_equal(get_scale(z.dtype), 0.2)
a = z.astype(quint8(0.3, 127))
np.testing.assert_almost_equal(get_scale(a.dtype), 0.3)
np.testing.assert_equal(get_zero_point(a.dtype), 127)
b = a.astype(quint8(0.3, 128))
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3)
np.testing.assert_equal(get_zero_point(b.dtype), 128)

Loading…
Cancel
Save