Browse Source

fix(mge/autodiff): check tensors to be attached

GitOrigin-RevId: c4f3f80876
release-1.4
Megvii Engine Team 4 years ago
parent
commit
1ecf2ab4ec
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      imperative/python/megengine/autodiff/grad_manager.py

+ 3
- 2
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -1,7 +1,7 @@
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable
from typing import Callable, Iterable


from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad from ..core.autodiff.grad import Grad
@@ -121,7 +121,7 @@ class GradManager:
self._after_backward_callback = [] self._after_backward_callback = []
self._gradients = {} self._gradients = {}


def attach(self, tensors: list, callbacks=None):
def attach(self, tensors: Iterable[Tensor], callbacks=None):
r""" r"""
Instruct GradManager to track operations on tensors, so that gradients with respect Instruct GradManager to track operations on tensors, so that gradients with respect
to those tensors could be evaluated later. to those tensors could be evaluated later.
@@ -199,6 +199,7 @@ class GradManager:
return spec return spec


for x in tensors: for x in tensors:
assert isinstance(x, Tensor), "Object to be attached should be Tensor"
spec = self._attach_specs.get(id(x)) spec = self._attach_specs.get(id(x))
new_attach = spec is None new_attach = spec is None
if spec is None: if spec is None:


Loading…
Cancel
Save