GitOrigin-RevId: f0aaea99b9
tags/v1.9.0
@@ -6,11 +6,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from collections import OrderedDict
 from typing import Callable, Iterable, List, Union
 
 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
+from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
 from ..tensor import Tensor
 from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:
         for x in tensors:
             assert isinstance(x, Tensor), "Object to be attached should be Tensor"
+            assert is_differentible_dtype(x.dtype), (
+                "Only tensors of floating point dtype can be attached to get gradients, "
+                "got tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
+            )
             spec = self._attach_specs.get(id(x))
             new_attach = spec is None
             if spec is None:
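
For context, a minimal usage sketch (my own illustration, not part of this commit, assuming MegEngine's public GradManager and tensor API): with the new assertion, attaching a non-floating-point tensor fails immediately at attach() time instead of surfacing later during backward().

# Hypothetical sketch of the user-visible effect of the new dtype check.
import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

gm = GradManager()

w = mge.tensor([1.0, 2.0], dtype=np.float32)
gm.attach([w])                    # ok: float32 is a differentiable dtype

idx = mge.tensor([0, 1], dtype=np.int32)
try:
    gm.attach([idx])              # now raises AssertionError up front
except AssertionError as exc:
    print(exc)
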
@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
     return dtype is bfloat16
 
 
+def is_differentible_dtype(dtype):
+    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)
+
+
 # quantization dtype related
 # use namedtuple to make class immutable, comparable and easy to print
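
A brief aside (an illustration I'm adding, assuming standard NumPy semantics, not taken from the diff): the helper compares with ==, so it accepts both NumPy scalar types and np.dtype objects, which is what a tensor's dtype attribute typically yields.

# Hypothetical illustration: np.dtype equality makes both spellings work.
import numpy as np

print(np.dtype("float32") == np.float32)   # True
print(np.float16 == np.dtype("float16"))   # True
print(np.dtype("int32") == np.float32)     # False -> would fail the new assert
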
@@ -114,7 +118,7 @@ def create_quantized_dtype(
     dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
 ):
     r"""Get quantized dtype with metadata attribute according to _metadata_dict.
 
     Note that unsigned dtype must have ``zero_point`` and signed dtype must
     not have ``zero_point``, to be consistent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
@@ -13,6 +13,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.core.tensor.dtype as dtype
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
     np.testing.assert_almost_equal(
         x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
     )
+
+
+@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
+def test_attach_invalid_tensor_dtype(invalid_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=invalid_dtype)
+    with pytest.raises(AssertionError):
+        gm.attach([x])
+
+
+@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
+def test_attach_differentible_tensor_dtype(differentible_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=differentible_dtype)
+    gm.attach([x])
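
To round out the picture, a minimal end-to-end sketch (my own example using the standard GradManager record/backward flow, not taken from this commit), showing that the ordinary float32 path is unaffected by the new check:

# Hypothetical sketch: normal float32 training path still works as before.
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

w = mge.tensor([2.0], dtype=np.float32)
gm = GradManager()
gm.attach([w])

with gm:                      # record operations on attached tensors
    loss = F.sum(w * w)
    gm.backward(loss)         # accumulate gradients into w.grad

print(w.grad.numpy())         # ~[4.0], since d(w*w)/dw = 2*w at w = 2
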