Browse Source

fix(mge/core): fix dtype promotion issue for quantized dtype

GitOrigin-RevId: 9d09c8fa6f
release-1.1
Megvii Engine Team 4 years ago
parent
commit
82b0f67770
3 changed files with 19 additions and 8 deletions
  1. +2
    -2
      imperative/python/megengine/core/tensor/utils.py
  2. +1
    -6
      imperative/python/megengine/module/module.py
  3. +16
    -0
      imperative/python/test/unit/functional/test_elemwise.py

+ 2
- 2
imperative/python/megengine/core/tensor/utils.py View File

@@ -14,7 +14,7 @@ import numpy as np
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .dtype import is_equal
from .dtype import is_equal, is_quantize




def dtype_promotion(inputs): def dtype_promotion(inputs):
@@ -122,7 +122,7 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
assert len(tensors) > 0 assert len(tensors) > 0
if isinstance(v, (TensorWrapperBase, TensorBase)): if isinstance(v, (TensorWrapperBase, TensorBase)):
v = astype(v, dtype)
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
else: else:
(v,) = Const(v, dtype=dtype, device=device)(*tensors) (v,) = Const(v, dtype=dtype, device=device)(*tensors)
return v return v


+ 1
- 6
imperative/python/megengine/module/module.py View File

@@ -12,7 +12,6 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union


import numpy as np import numpy as np


from ..core.tensor.dtype import is_quantize
from ..core.tensor.utils import make_shape_tuple from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger from ..logger import get_logger
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
@@ -529,11 +528,7 @@ class Module(metaclass=ABCMeta):
), "param `{}` shape mismatch, should be {}, get {}".format( ), "param `{}` shape mismatch, should be {}, get {}".format(
k, var.shape, to_be_load.shape k, var.shape, to_be_load.shape
) )
# For quantized dtype, the initialized dtype
# scale/zero_points maybe invalid, use pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var = var.astype(to_be_load.dtype)
var._reset(to_be_load)
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
loaded.append(k) loaded.append(k)


return set(loaded), set(skipped) return set(loaded), set(skipped)


+ 16
- 0
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -10,6 +10,7 @@ import numpy as np


import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import tensor
from megengine.core.tensor import dtype
from megengine.functional.elemwise import _elwise from megengine.functional.elemwise import _elwise




@@ -150,3 +151,18 @@ def test_logical_oprs():
np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy()) np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy())
np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy()) np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy())
np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy()) np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())


def test_qadd():
inp_scale = 0.5
outp_scale = 0.2
x = np.arange(6).reshape(2, 3).astype("float32")
y = np.arange(6).reshape(2, 3).astype("float32")
x = tensor(x, dtype=dtype.qint8(inp_scale))
y = tensor(y, dtype=dtype.qint8(inp_scale))
result_mge = F.elemwise._elemwise_multi_type(
x, y, mode="QADD", dtype=dtype.qint8(outp_scale)
)
result_mge = result_mge.astype("float32").numpy()
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)

Loading…
Cancel
Save