diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 5fab6f14..84b670a8 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -6,11 +6,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from collections import OrderedDict
 from typing import Callable, Iterable, List, Union
 
 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
+from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
 from ..tensor import Tensor
 from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:
         for x in tensors:
             assert isinstance(x, Tensor), "Object to be attached should be Tensor"
+            assert is_differentible_dtype(x.dtype), (
+                "Only tensors of floating point dtype can be attached to get gradients, "
+                "got tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
+            )
             spec = self._attach_specs.get(id(x))
             new_attach = spec is None
             if spec is None:
diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py
index 248cbdaf..0fccae18 100644
--- a/imperative/python/megengine/core/tensor/dtype.py
+++ b/imperative/python/megengine/core/tensor/dtype.py
@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
     return dtype is bfloat16
 
 
+def is_differentible_dtype(dtype):
+    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)
+
+
 # quantization dtype related
 # use namedtuple to make class immutable, comparable and easy to print
 
@@ -114,7 +118,7 @@ def create_quantized_dtype(
     dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
 ):
     r"""Get quantized dtype with metadata attribute according to _metadata_dict.
-    
+
     Note that unsigned dtype must have ``zero_point`` and signed dtype must not
     have ``zero_point``, to be consitent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index eb1f5a4c..393d689f 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -13,6 +13,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.core.tensor.dtype as dtype
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
     np.testing.assert_almost_equal(
         x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
     )
+
+
+@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
+def test_attach_invalid_tensor_dtype(invalid_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=invalid_dtype)
+    with pytest.raises(AssertionError):
+        gm.attach([x])
+
+
+@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
+def test_attach_differentible_tensor_dtype(differentible_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=differentible_dtype)
+    gm.attach([x])
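
Illustrative usage (not part of the patch): a minimal sketch of how GradManager.attach behaves once this check is in place, assuming the public megengine.autodiff.GradManager and megengine.tensor APIs.

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()

    # Floating point tensors attach as before.
    w = mge.tensor([1.0], dtype=np.float32)
    gm.attach([w])

    # Integer tensors now fail fast with an AssertionError instead of
    # silently producing no gradient after backward().
    idx = mge.tensor([1], dtype=np.int32)
    try:
        gm.attach([idx])
    except AssertionError as exc:
        print(exc)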