Browse Source

fix(mge/imperative): fix tracer leak

GitOrigin-RevId: 1a8ac20b46
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
9005cf74d2
2 changed files with 45 additions and 1 deletions
  1. +13
    -1
      imperative/python/megengine/core/autodiff/grad.py
  2. +32
    -0
      imperative/python/test/unit/core/test_autodiff.py

+ 13
- 1
imperative/python/megengine/core/autodiff/grad.py View File

@@ -82,6 +82,7 @@ class Grad:
# ops forms the computational graph
self.ops = []

self._attached_tensors = weakref.WeakSet()
self._enabled = True

@property
@@ -107,6 +108,7 @@ class Grad:
return self

def _new_variable(self, owner, opnode=None, callback=None):
    """Build a ``VariableNode`` for *owner* and register the tensor.

    The owner tensor is recorded in ``self._attached_tensors`` so the
    per-tensor state keyed by this Grad can be stripped on release.
    """
    self._attached_tensors.add(owner)
    node = VariableNode(self, owner, opnode=opnode, callback=callback)
    return node

def _new_opnode(self, inputs, outputs):
@@ -131,13 +133,18 @@ class Grad:
def __enter__(self):
    """Enter the ``with`` block; the Grad object itself is the context value."""
    return self

def __exit__(self, *_):
def _exit(self):
    """Clear all resources held by this Grad.

    Disables further recording, clears every recorded op node, and pops
    this Grad's entry from the ``_extra_data`` of every attached tensor.
    Written to be idempotent: it is invoked from ``__exit__``, from the
    end of ``__call__``, and from ``__del__``.
    """
    self._enabled = False
    # self.ops holds weak references to op nodes; dereference and clear
    # the ones still alive.
    for op_ref in self.ops:
        op = op_ref()
        if op:
            op.clear()
    # Snapshot the WeakSet before iterating: a weakly-referenced tensor
    # may be garbage-collected mid-loop, and a WeakSet changing size
    # during iteration raises RuntimeError.
    for tensor in tuple(self._attached_tensors):
        tensor._extra_data.pop(self, None)

def __exit__(self, *_):
    """Leave the ``with`` block, releasing all resources.

    Exception information is ignored and never suppressed (returns None).
    """
    self._exit()

def __call__(self, ys, dys):
""" Defines Grad().
@@ -275,6 +282,11 @@ class Grad:
for v in cache.values():
assert v is None

self._exit()

def __del__(self):
    """Release resources even when the Grad is dropped without ``with``/``__call__``."""
    self._exit()


class clearable:
__cleared = False


+ 32
- 0
imperative/python/test/unit/core/test_autodiff.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import platform
import weakref

@@ -158,6 +159,37 @@ def test_grad_with_tensor_wrapper():
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


def test_release():
    """Verify that a Grad instance frees all its resources in every usage mode."""

    def check(f):
        # Run f() a few times; after warm-up the count of live objects
        # must stop growing, i.e. each extra run leaks nothing.
        prev = 0
        delta = None
        for _ in range(3):
            f()
            cur = len(gc.get_objects())
            delta = cur - prev
            prev = cur
        assert delta == 0

    x = TensorWrapper([0.0])
    dy = TensorWrapper(np.ones_like(x.numpy()))

    # Full backward pass: grad is released at the end of __call__.
    @check
    def _():
        g = Grad().wrt(x)
        y = x * x
        g(y, dy)

    # Context manager with no recorded ops.
    @check
    def _():
        with Grad().wrt(x) as g:
            pass

    # Context manager with recorded ops but no backward pass.
    @check
    def _():
        with Grad().wrt(x) as g:
            y = x * x


def test_grad_inplace():
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)


Loading…
Cancel
Save