From 9005cf74d2cf86b874f500e0c3e901267347496b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 15 Sep 2020 18:21:57 +0800 Subject: [PATCH] fix(mge/imperative): fix tracer leak GitOrigin-RevId: 1a8ac20b46e25897a34c62da088dbf49e0e19ee6 --- imperative/python/megengine/core/autodiff/grad.py | 14 +++++++++- imperative/python/test/unit/core/test_autodiff.py | 32 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index c9d93e46..7562fbb4 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -82,6 +82,7 @@ class Grad: # ops forms the computational graph self.ops = [] + self._attached_tensors = weakref.WeakSet() self._enabled = True @property @@ -107,6 +108,7 @@ class Grad: return self def _new_variable(self, owner, opnode=None, callback=None): + self._attached_tensors.add(owner) return VariableNode(self, owner, opnode=opnode, callback=callback) def _new_opnode(self, inputs, outputs): @@ -131,13 +133,18 @@ class Grad: def __enter__(self): return self - def __exit__(self, *_): + def _exit(self): """clear all resources""" self._enabled = False for o in self.ops: o = o() if o: o.clear() + for i in self._attached_tensors: + i._extra_data.pop(self, None) + + def __exit__(self, *_): + self._exit() def __call__(self, ys, dys): """ Defines Grad(). @@ -275,6 +282,11 @@ class Grad: for v in cache.values(): assert v is None + self._exit() + + def __del__(self): + self._exit() + class clearable: __cleared = False diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 3caaed61..1cc7d453 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -6,6 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import gc import platform import weakref @@ -158,6 +159,37 @@ def test_grad_with_tensor_wrapper(): np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) +def test_release(): + def check(f): + n = 0 + d = None + for i in range(3): + f() + m = len(gc.get_objects()) + d = m - n + n = m + assert d == 0 + + x = TensorWrapper([0.0]) + dy = TensorWrapper(np.ones_like(x.numpy())) + + @check + def _(): + g = Grad().wrt(x) + y = x * x + g(y, dy) + + @check + def _(): + with Grad().wrt(x) as g: + pass + + @check + def _(): + with Grad().wrt(x) as g: + y = x * x + + def test_grad_inplace(): x_np = np.random.rand(10).astype("float32") x = TensorWrapper(x_np)