|
|
@@ -100,21 +100,6 @@ class GradManager: |
|
|
|
self._record_param(id(p)) |
|
|
|
return self |
|
|
|
|
|
|
|
def detach(self, params: list): |
|
|
|
r""" |
|
|
|
Remove specific registered parameters and callback functions. |
|
|
|
|
|
|
|
:param params: registered parameters |
|
|
|
""" |
|
|
|
if isinstance(params, Tensor): |
|
|
|
params = [params] |
|
|
|
for idx, param in enumerate(params): |
|
|
|
if id(param) in self._param_dict: |
|
|
|
self._param_dict.pop(id(param)) |
|
|
|
self._call_back_dict.pop(id(param)) |
|
|
|
else: |
|
|
|
logger.warning("params with index {} is not attached.".format(idx)) |
|
|
|
|
|
|
|
def _register_after_backward_callback(self, callback): |
|
|
|
self._after_backward_callback.append(callback) |
|
|
|
return self |
|
|
|