Browse Source

fix(mge/imperative): fix tracer leak

GitOrigin-RevId: 1a8ac20b46
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
9005cf74d2
2 changed files with 45 additions and 1 deletions
  1. +13
    -1
      imperative/python/megengine/core/autodiff/grad.py
  2. +32
    -0
      imperative/python/test/unit/core/test_autodiff.py

+ 13
- 1
imperative/python/megengine/core/autodiff/grad.py View File

@@ -82,6 +82,7 @@ class Grad:
# ops forms the computational graph # ops forms the computational graph
self.ops = [] self.ops = []


self._attached_tensors = weakref.WeakSet()
self._enabled = True self._enabled = True


@property @property
@@ -107,6 +108,7 @@ class Grad:
return self return self


def _new_variable(self, owner, opnode=None, callback=None): def _new_variable(self, owner, opnode=None, callback=None):
self._attached_tensors.add(owner)
return VariableNode(self, owner, opnode=opnode, callback=callback) return VariableNode(self, owner, opnode=opnode, callback=callback)


def _new_opnode(self, inputs, outputs): def _new_opnode(self, inputs, outputs):
@@ -131,13 +133,18 @@ class Grad:
def __enter__(self): def __enter__(self):
return self return self


def __exit__(self, *_):
def _exit(self):
"""clear all resources""" """clear all resources"""
self._enabled = False self._enabled = False
for o in self.ops: for o in self.ops:
o = o() o = o()
if o: if o:
o.clear() o.clear()
for i in self._attached_tensors:
i._extra_data.pop(self, None)

def __exit__(self, *_):
self._exit()


def __call__(self, ys, dys): def __call__(self, ys, dys):
""" Defines Grad(). """ Defines Grad().
@@ -275,6 +282,11 @@ class Grad:
for v in cache.values(): for v in cache.values():
assert v is None assert v is None


self._exit()

def __del__(self):
self._exit()



class clearable: class clearable:
__cleared = False __cleared = False


+ 32
- 0
imperative/python/test/unit/core/test_autodiff.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import platform import platform
import weakref import weakref


@@ -158,6 +159,37 @@ def test_grad_with_tensor_wrapper():
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)




def test_release():
def check(f):
n = 0
d = None
for i in range(3):
f()
m = len(gc.get_objects())
d = m - n
n = m
assert d == 0

x = TensorWrapper([0.0])
dy = TensorWrapper(np.ones_like(x.numpy()))

@check
def _():
g = Grad().wrt(x)
y = x * x
g(y, dy)

@check
def _():
with Grad().wrt(x) as g:
pass

@check
def _():
with Grad().wrt(x) as g:
y = x * x


def test_grad_inplace(): def test_grad_inplace():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np) x = TensorWrapper(x_np)


Loading…
Cancel
Save