Browse Source

fix(mge/autodiff): check tensors to be attached

GitOrigin-RevId: c4f3f80876
release-1.4
Megvii Engine Team 4 years ago
parent
commit
1ecf2ab4ec
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      imperative/python/megengine/autodiff/grad_manager.py

+ 3
- 2
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -1,7 +1,7 @@
import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable
from typing import Callable, Iterable

from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad
@@ -121,7 +121,7 @@ class GradManager:
self._after_backward_callback = []
self._gradients = {}

def attach(self, tensors: list, callbacks=None):
def attach(self, tensors: Iterable[Tensor], callbacks=None):
r"""
Instruct GradManager to track operations on tensors, so that gradients with respect
to those tensors could be evaluated later.
@@ -199,6 +199,7 @@ class GradManager:
return spec

for x in tensors:
assert isinstance(x, Tensor), "Object to be attached should be Tensor"
spec = self._attach_specs.get(id(x))
new_attach = spec is None
if spec is None:


Loading…
Cancel
Save