GitOrigin-RevId: 9d09c8fa6f
release-1.1
@@ -14,7 +14,7 @@ import numpy as np | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | ||||
from .dtype import is_equal | |||||
from .dtype import is_equal, is_quantize | |||||
def dtype_promotion(inputs): | def dtype_promotion(inputs): | ||||
@@ -122,7 +122,7 @@ def convert_single_value(v, inputs, *, dtype=None, device=None): | |||||
tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] | tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] | ||||
assert len(tensors) > 0 | assert len(tensors) > 0 | ||||
if isinstance(v, (TensorWrapperBase, TensorBase)): | if isinstance(v, (TensorWrapperBase, TensorBase)): | ||||
v = astype(v, dtype) | |||||
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | |||||
else: | else: | ||||
(v,) = Const(v, dtype=dtype, device=device)(*tensors) | (v,) = Const(v, dtype=dtype, device=device)(*tensors) | ||||
return v | return v | ||||
@@ -12,7 +12,6 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
from ..core.tensor.dtype import is_quantize | |||||
from ..core.tensor.utils import make_shape_tuple | from ..core.tensor.utils import make_shape_tuple | ||||
from ..logger import get_logger | from ..logger import get_logger | ||||
from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
@@ -529,11 +528,7 @@ class Module(metaclass=ABCMeta): | |||||
), "param `{}` shape mismatch, should be {}, get {}".format( | ), "param `{}` shape mismatch, should be {}, get {}".format( | ||||
k, var.shape, to_be_load.shape | k, var.shape, to_be_load.shape | ||||
) | ) | ||||
# For quantized dtype, the initialized dtype | |||||
# scale/zero_points maybe invalid, use pretrained dtype instead. | |||||
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||||
var = var.astype(to_be_load.dtype) | |||||
var._reset(to_be_load) | |||||
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device)) | |||||
loaded.append(k) | loaded.append(k) | ||||
return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
@@ -10,6 +10,7 @@ import numpy as np | |||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.core.tensor import dtype | |||||
from megengine.functional.elemwise import _elwise | from megengine.functional.elemwise import _elwise | ||||
@@ -150,3 +151,18 @@ def test_logical_oprs(): | |||||
np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy()) | np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy()) | ||||
np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy()) | np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy()) | ||||
np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy()) | np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy()) | ||||
def test_qadd(): | |||||
inp_scale = 0.5 | |||||
outp_scale = 0.2 | |||||
x = np.arange(6).reshape(2, 3).astype("float32") | |||||
y = np.arange(6).reshape(2, 3).astype("float32") | |||||
x = tensor(x, dtype=dtype.qint8(inp_scale)) | |||||
y = tensor(y, dtype=dtype.qint8(inp_scale)) | |||||
result_mge = F.elemwise._elemwise_multi_type( | |||||
x, y, mode="QADD", dtype=dtype.qint8(outp_scale) | |||||
) | |||||
result_mge = result_mge.astype("float32").numpy() | |||||
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | |||||
np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) |