|
|
@@ -82,6 +82,7 @@ class Grad: |
|
|
|
# ops forms the computational graph |
|
|
|
self.ops = [] |
|
|
|
|
|
|
|
self._attached_tensors = weakref.WeakSet() |
|
|
|
self._enabled = True |
|
|
|
|
|
|
|
@property |
|
|
@@ -107,6 +108,7 @@ class Grad: |
|
|
|
return self |
|
|
|
|
|
|
|
def _new_variable(self, owner, opnode=None, callback=None): |
|
|
|
self._attached_tensors.add(owner) |
|
|
|
return VariableNode(self, owner, opnode=opnode, callback=callback) |
|
|
|
|
|
|
|
def _new_opnode(self, inputs, outputs): |
|
|
@@ -131,13 +133,18 @@ class Grad: |
|
|
|
def __enter__(self): |
|
|
|
return self |
|
|
|
|
|
|
|
def __exit__(self, *_): |
|
|
|
def _exit(self): |
|
|
|
"""clear all resources""" |
|
|
|
self._enabled = False |
|
|
|
for o in self.ops: |
|
|
|
o = o() |
|
|
|
if o: |
|
|
|
o.clear() |
|
|
|
for i in self._attached_tensors: |
|
|
|
i._extra_data.pop(self, None) |
|
|
|
|
|
|
|
def __exit__(self, *_): |
|
|
|
self._exit() |
|
|
|
|
|
|
|
def __call__(self, ys, dys): |
|
|
|
""" Defines Grad(). |
|
|
@@ -275,6 +282,11 @@ class Grad: |
|
|
|
for v in cache.values(): |
|
|
|
assert v is None |
|
|
|
|
|
|
|
self._exit() |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
self._exit() |
|
|
|
|
|
|
|
|
|
|
|
class clearable: |
|
|
|
__cleared = False |
|
|
|