
feat(mge): use weakref for GradManager.attach

GitOrigin-RevId: 6df336c3c1
Branch: release-1.1
Author: Megvii Engine Team, 4 years ago
Commit: 6667100638
3 changed files with 88 additions and 30 deletions:
  1. imperative/python/megengine/autodiff/grad_manager.py (+53, -30)
  2. imperative/python/megengine/distributed/helper.py (+1, -0)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+34, -0)
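
What the change does: attach() used to keep strong references to every attached tensor (self._param_dict / self._call_back_dict), so a temporary tensor attached during a training step stayed alive as long as the GradManager did. The new AttachSpec keeps only a weak reference, and its deleter removes the spec from the manager once the tensor is garbage-collected. Below is a minimal, self-contained sketch of that weakref-deleter pattern in plain Python; it is an illustration only (Registry, Spec, and Thing are made-up names, not MegEngine API):

import weakref


class Spec:
    # mirrors AttachSpec: a weak reference to the object plus its callbacks
    __slots__ = "ref", "callbacks"


class Registry:
    # specs keyed by id(obj); an entry disappears when its object is collected
    def __init__(self):
        self._specs = {}  # id(obj) -> Spec

    def attach(self, obj, callback=None):
        key = id(obj)
        spec = self._specs.get(key)
        if spec is None:
            # weak self-reference so the deleter closure does not keep the registry alive
            selfref = weakref.ref(self)

            def deleter(_):
                registry = selfref()
                if registry is not None:
                    del registry._specs[key]

            spec = Spec()
            spec.ref = weakref.ref(obj, deleter)  # deleter fires when obj is collected
            spec.callbacks = []
            self._specs[key] = spec
        if callback is not None:
            spec.callbacks.append(callback)
        return self


class Thing:
    pass  # weakref requires a type that supports weak references (not int/str)


registry = Registry()
t = Thing()
registry.attach(t, callback=lambda obj, grad: grad)
assert len(registry._specs) == 1
del t  # last strong reference gone; CPython collects immediately...
assert len(registry._specs) == 0  # ...and the weakref deleter has pruned the spec

The same pattern appears in the grad_manager.py diff below: make_spec() builds an AttachSpec whose weakref deleter removes the id(tensor) entry from self._attach_specs.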

imperative/python/megengine/autodiff/grad_manager.py (+53, -30)

@@ -1,3 +1,4 @@
+import weakref
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Callable
@@ -16,6 +17,10 @@ def get_backwarding_grad_manager():
     return backwarding_grad_manager
 
 
+class AttachSpec:
+    __slots__ = "tensor", "callbacks"
+
+
 class GradManager:
     r"""
     GradManager manages auto differentiation and all resources required to perform it.
@@ -64,14 +69,13 @@ class GradManager:
     """
 
     def __init__(self):
-        self._call_back_dict = defaultdict(list)
-        self._param_dict = dict()
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
-        self._gradients = dict()
+        self._gradients = {}
 
-    def attach(self, params: list, callbacks=None):
+    def attach(self, tensors: list, callbacks=None):
         r"""
         Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
@@ -89,22 +93,39 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
-        if isinstance(params, Tensor):
-            params = [params]
-        for p in params:
-            self._param_dict[id(p)] = p
-            for cb in callbacks:
-                self._call_back_dict[id(p)].append(cb)
-        if self._grad is not None:
-            for p in params:
-                self._record_param(id(p))
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+
+        def make_spec(tensor):
+            selfref = weakref.ref(self)
+            key = id(tensor)
+
+            def deleter(_):
+                self = selfref()
+                if self is not None:
+                    del self._attach_specs[key]
+
+            spec = AttachSpec()
+            spec.tensor = weakref.ref(tensor, deleter)
+            spec.callbacks = []
+            return spec
+
+        for x in tensors:
+            spec = self._attach_specs.get(id(x))
+            new_attach = spec is None
+            if spec is None:
+                spec = make_spec(x)
+                self._attach_specs[id(x)] = spec
+            spec.callbacks.extend(callbacks)
+            if new_attach and self._recording:
+                self._do_record(spec)
         return self
 
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
 
-    def backward(self, ys=None, dys=None):
+    def backward(self, y=None, dy=None):
         r"""
         Performs back-propagation and computes gradients.
 
@@ -135,14 +156,16 @@
             self._grad(ys, dys)
             for callback in self._after_backward_callback:
                 callback()
-            for p, grad in self._gradients.items():
+            for id_, grad in self._gradients.items():
                 if isinstance(grad, Future):
                     grad = grad.get()
-                param = self._param_dict[p]
-                if param.grad is None:
-                    param.grad = grad
-                else:
-                    param.grad += grad
+                spec = self._attach_specs.get(id_)
+                tensor = spec and spec.tensor()
+                if tensor is not None:
+                    if tensor.grad is None:
+                        tensor.grad = grad
+                    else:
+                        tensor.grad += grad
         finally:
             self.release()
             backwarding_grad_manager = cache
@@ -156,22 +179,22 @@
         grad = Grad()
         self._recording = True
         self._grad = grad
-        for param_id in self._param_dict.keys():
-            self._record_param(param_id)
+        for spec in self._attach_specs.values():
+            self._do_record(spec)
         grad.__enter__()
 
-    def _record_param(self, param_id):
-        param_wrapper = self._param_dict[param_id]
-        callbacks = self._call_back_dict[param_id]
+    def _do_record(self, spec):
+        tensor = spec.tensor()
+        if tensor is None:
+            return
 
-        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-            ret = grad
+        def callback(_, grad, callbacks=spec.callbacks):
             for cb in callbacks:
-                ret = cb(param, ret)
-            gm._gradients[id(p)] = ret
+                grad = cb(tensor, grad)
+            self._gradients[id(tensor)] = grad
 
         # NOTE: override prev callback wrt when called several times
-        self._grad.wrt(param_wrapper, callback=callback)
+        self._grad.wrt(tensor, callback=callback)
 
     def release(self):
         r"""


imperative/python/megengine/distributed/helper.py (+1, -0)

@@ -224,6 +224,7 @@ class AllreduceCallback:
         self._packing_size[dtype] = 0
 
     def __call__(self, param, grad):
+        param = param.__wrapped__
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:


imperative/python/test/unit/autodiff/test_grad_manger.py (+34, -0)

@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import platform
+import weakref
 
 import numpy as np
 import pytest
@@ -59,6 +60,39 @@ def test_attach_in_with_block():
     assert int(b.grad.numpy()) == 1
 
 
+def test_attach_temporary():
+    w = mge.Parameter(2.0)
+    gm = GradManager()
+    gm.attach(w)
+
+    def cb(x, g):
+        assert x is ref()
+        cb.called = True
+
+    for i in range(3):
+        with gm:
+            cb.called = False
+            x = mge.Tensor(i, dtype="float32")
+            gm.attach(x, callbacks=cb)
+            ref = weakref.ref(x)
+            y = x * w
+            gm.backward(y)
+            assert cb.called
+        del x
+        assert ref() is None
+
+    # NOTE: does not guarantee timely release when recording
+    # for i in range(3):
+    #     with gm:
+    #         x = mge.Tensor(i, dtype='float32')
+    #         gm.attach(x)
+    #         ref = weakref.ref(x)
+    #         y = x * w
+    #         del x
+    #         assert ref() is None
+    #         gm.backward(y)
+
+
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )

