|
|
@@ -1,7 +1,7 @@ |
|
|
|
import weakref |
|
|
|
from collections import defaultdict |
|
|
|
from contextlib import contextmanager |
|
|
|
from typing import Callable |
|
|
|
from typing import Callable, Iterable |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option |
|
|
|
from ..core.autodiff.grad import Grad |
|
|
@@ -121,7 +121,7 @@ class GradManager: |
|
|
|
self._after_backward_callback = [] |
|
|
|
self._gradients = {} |
|
|
|
|
|
|
|
def attach(self, tensors: list, callbacks=None): |
|
|
|
def attach(self, tensors: Iterable[Tensor], callbacks=None): |
|
|
|
r""" |
|
|
|
Instruct GradManager to track operations on tensors, so that gradients with respect |
|
|
|
to those tensors could be evaluated later. |
|
|
@@ -199,6 +199,7 @@ class GradManager: |
|
|
|
return spec |
|
|
|
|
|
|
|
for x in tensors: |
|
|
|
assert isinstance(x, Tensor), "Object to be attached should be Tensor" |
|
|
|
spec = self._attach_specs.get(id(x)) |
|
|
|
new_attach = spec is None |
|
|
|
if spec is None: |
|
|
|