
fix(mge/gm): fix missing dtype checking while attaching tensors

GitOrigin-RevId: f0aaea99b9
tags/v1.9.0
Megvii Engine Team, 3 years ago
parent commit: 87f00232f2
3 changed files with 26 additions and 2 deletions
  1. imperative/python/megengine/autodiff/grad_manager.py (+5, -1)
  2. imperative/python/megengine/core/tensor/dtype.py (+5, -1)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+16, -0)

imperative/python/megengine/autodiff/grad_manager.py (+5, -1)

@@ -6,11 +6,11 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref
from collections import OrderedDict
from typing import Callable, Iterable, List, Union

from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad
from ..core.tensor.dtype import is_differentible_dtype
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:

        for x in tensors:
            assert isinstance(x, Tensor), "Object to be attached should be Tensor"
            assert is_differentible_dtype(x.dtype), (
                "Only tensors of floating point dtype can be attached to get gradients, "
                "get tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
            )
            spec = self._attach_specs.get(id(x))
            new_attach = spec is None
            if spec is None:
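With this change, GradManager.attach rejects non-floating-point tensors up front instead of failing later (or silently yielding no gradient) during backward. A minimal sketch of the new behavior (tensor values are illustrative):

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()

    w = mge.tensor([1.0, 2.0], dtype=np.float32)
    gm.attach([w])  # ok: float32 is a differentiable dtype

    idx = mge.tensor([0, 1], dtype=np.int32)
    gm.attach([idx])  # now raises AssertionError, reporting dtype and shape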


imperative/python/megengine/core/tensor/dtype.py (+5, -1)

@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
    return dtype is bfloat16


def is_differentible_dtype(dtype):
    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)


# quantization dtype related

# use namedtuple to make class immutable, comparable and easy to print
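A quick sketch of what is_differentible_dtype accepts and rejects, assuming bfloat16 is importable from the same module, as the surrounding context suggests:

    import numpy as np
    from megengine.core.tensor.dtype import bfloat16, is_differentible_dtype

    assert is_differentible_dtype(np.float32)
    assert is_differentible_dtype(np.float16)
    assert is_differentible_dtype(bfloat16)
    assert not is_differentible_dtype(np.int32)  # integer dtypes are rejected
    assert not is_differentible_dtype(np.uint8)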
@@ -114,7 +118,7 @@ def create_quantized_dtype(
    dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
):
    r"""Get quantized dtype with metadata attribute according to _metadata_dict.
    Note that unsigned dtype must have ``zero_point`` and signed dtype must
    not have ``zero_point``, to be consistent with tensor generated by calling
    compiled function from `CompGraph.compile(inputs, outspec)`.
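To illustrate that zero_point convention, a hedged sketch using MegEngine's quantized dtype factories (quint8/qint8 are assumed to be the module's public helpers; exact signatures may vary by version):

    from megengine.core.tensor.dtype import qint8, quint8

    uq = quint8(0.1, 128)  # unsigned dtype: zero_point required
    sq = qint8(0.1)        # signed dtype: no zero_point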


imperative/python/test/unit/autodiff/test_grad_manger.py (+16, -0)

@@ -13,6 +13,7 @@ import numpy as np
import pytest

import megengine as mge
import megengine.core.tensor.dtype as dtype
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
    np.testing.assert_almost_equal(
        x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
    )


@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
def test_attach_invalid_tensor_dtype(invalid_dtype):
gm = GradManager()
x = mge.tensor([1], dtype=invalid_dtype)
with pytest.raises(AssertionError):
gm.attach([x])


@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
def test_attach_differentible_tensor_dtype(differentible_dtype):
gm = GradManager()
x = mge.tensor([1], dtype=differentible_dtype)
gm.attach([x])
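Note that the check applies only at attach time; integer tensors can still feed the forward pass (e.g. as indices), they just cannot be attached for gradients. A small sketch under that assumption:

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()
    w = mge.tensor(np.ones((4,), dtype=np.float32))
    gm.attach([w])

    idx = mge.tensor([0, 2], dtype=np.int32)  # not attached: fine as an input
    with gm:
        y = (w[idx] ** 2).sum()
        gm.backward(y)
    print(w.grad.numpy())  # gradient reaches only the attached float tensor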
