
fix(mge/gm): fix missing dtype check when attaching tensors

GitOrigin-RevId: f0aaea99b9
tags/v1.9.0
Megvii Engine Team 3 years ago
parent commit 87f00232f2
3 changed files with 26 additions and 2 deletions
  1. imperative/python/megengine/autodiff/grad_manager.py (+5, -1)
  2. imperative/python/megengine/core/tensor/dtype.py (+5, -1)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+16, -0)

imperative/python/megengine/autodiff/grad_manager.py (+5, -1)

@@ -6,11 +6,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from collections import OrderedDict
 from typing import Callable, Iterable, List, Union

 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
+from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
 from ..tensor import Tensor
 from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:

         for x in tensors:
             assert isinstance(x, Tensor), "Object to be attached should be Tensor"
+            assert is_differentible_dtype(x.dtype), (
+                "Only tensors of floating point dtype can be attached to get gradients, "
+                "get tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
+            )
             spec = self._attach_specs.get(id(x))
             new_attach = spec is None
             if spec is None:
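To illustrate the effect of the new assertion, here is a minimal sketch (the tensor values are illustrative, not from the commit): attaching an integer tensor now fails fast instead of silently producing unusable gradients.

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()
    x = mge.tensor([1, 2, 3], dtype=np.int32)  # integer dtype: not differentiable
    try:
        gm.attach([x])  # raises AssertionError after this commit
    except AssertionError as e:
        print(e)  # "Only tensors of floating point dtype can be attached ..."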


imperative/python/megengine/core/tensor/dtype.py (+5, -1)

@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
     return dtype is bfloat16


+def is_differentible_dtype(dtype):
+    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)
+
+
 # quantization dtype related

 # use namedtuple to make class immutable, comparable and easy to print
@@ -114,7 +118,7 @@ def create_quantized_dtype(
     dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
 ):
     r"""Get quantized dtype with metadata attribute according to _metadata_dict.

     Note that unsigned dtype must have ``zero_point`` and signed dtype must
     not have ``zero_point``, to be consistent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
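For reference, a small sketch of what the new helper accepts. That the module-level `bfloat16` name is importable is an assumption inferred from the surrounding code (`is_bfloat16` compares against it directly):

    import numpy as np
    from megengine.core.tensor.dtype import bfloat16, is_differentible_dtype

    assert is_differentible_dtype(np.float32)    # floating point: differentiable
    assert is_differentible_dtype(np.float16)
    assert is_differentible_dtype(bfloat16)      # matched via is_bfloat16
    assert not is_differentible_dtype(np.int32)  # integers carry no gradients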


imperative/python/test/unit/autodiff/test_grad_manger.py (+16, -0)

@@ -13,6 +13,7 @@ import numpy as np
 import pytest

 import megengine as mge
+import megengine.core.tensor.dtype as dtype
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
     np.testing.assert_almost_equal(
         x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
     )
+
+
+@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
+def test_attach_invalid_tensor_dtype(invalid_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=invalid_dtype)
+    with pytest.raises(AssertionError):
+        gm.attach([x])
+
+
+@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
+def test_attach_differentible_tensor_dtype(differentible_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=differentible_dtype)
+    gm.attach([x])
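An end-to-end sketch of the behavior these tests pin down, using the public GradManager API: a float tensor still attaches cleanly and receives gradients as before (the example computation is my own, not from the commit).

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    gm = GradManager()
    x = mge.tensor([1.0, 2.0], dtype=np.float32)  # float dtype: attach succeeds
    gm.attach([x])
    with gm:
        y = F.sum(x * x)
        gm.backward(y)
    print(x.grad.numpy())  # expect [2. 4.], since d/dx sum(x^2) = 2x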
