|
|
@@ -6,11 +6,11 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import weakref |
|
|
|
from collections import OrderedDict |
|
|
|
from typing import Callable, Iterable, List, Union |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option |
|
|
|
from ..core.autodiff.grad import Grad |
|
|
|
from ..core.tensor.dtype import is_differentible_dtype |
|
|
|
from ..logger import get_logger |
|
|
|
from ..tensor import Tensor |
|
|
|
from ..utils.future import Future |
|
|
@@ -208,6 +208,10 @@ class GradManager: |
|
|
|
|
|
|
|
for x in tensors: |
|
|
|
assert isinstance(x, Tensor), "Object to be attached should be Tensor" |
|
|
|
assert is_differentible_dtype(x.dtype), ( |
|
|
|
"Only tensors of floating point dtype can be attached to get gradients, " |
|
|
|
"get tensor dtype: {} and shape: {}".format(x.dtype, x.shape) |
|
|
|
) |
|
|
|
spec = self._attach_specs.get(id(x)) |
|
|
|
new_attach = spec is None |
|
|
|
if spec is None: |
|
|
|