
refactor(mge/optimizer): refine gradmanager api, record = __enter__

GitOrigin-RevId: 5376177237
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent commit c7acba41fc
3 changed files with 42 additions and 34 deletions
  1. imperative/python/megengine/autodiff/grad_manager.py  (+35, -30)
  2. imperative/python/megengine/distributed/helper.py  (+4, -3)
  3. imperative/python/megengine/optimizer/optimizer.py  (+3, -1)

imperative/python/megengine/autodiff/grad_manager.py  (+35, -30)

@@ -1,5 +1,6 @@
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Callable

 from ..core.autodiff.grad import Grad
 from ..tensor import tensor

@@ -21,7 +22,11 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()

-    def register(self, params, callbacks=[]):
+    def register(self, params, callbacks=None):
+        if callbacks is None:
+            callbacks = []
+        if isinstance(callbacks, Callable):
+            callbacks = [callbacks]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
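
A note on the register() hunk above: the old signature used a mutable default argument (callbacks=[]), which Python evaluates once at definition time and then shares across every call, so state appended in one call leaks into later calls. The new signature defaults to None, builds a fresh list per call, and also accepts a single callable, normalizing it to a one-element list via isinstance(callbacks, Callable). The pitfall itself is general Python behaviour; the toy functions below are illustrative only and are not MegEngine code.

def register_buggy(name, callbacks=[]):
    # the default list is created once and reused for every call
    callbacks.append(name)
    return callbacks

def register_fixed(name, callbacks=None):
    # a fresh list is created on each call that omits the argument
    if callbacks is None:
        callbacks = []
    callbacks.append(name)
    return callbacks

print(register_buggy("a"))   # ['a']
print(register_buggy("b"))   # ['a', 'b']  <- state leaked from the first call
print(register_fixed("a"))   # ['a']
print(register_fixed("b"))   # ['b']       <- independent, as intended
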
@@ -62,37 +67,37 @@ class GradManager:
             else:
                 param.grad += grad
         finally:
-            self._grad = None
-            self._gradients = dict()
+            self._stop_record()
             backwarding_grad_manager = cache

-    def record(self):
-        @contextmanager
-        def recorder():
-            grad = Grad()
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for param_id in self._param_dict.keys():
-                    param_wrapper = self._param_dict[param_id]
-                    callbacks = self._call_back_dict[param_id]
+    def __enter__(self):
+        if self._recording:
+            return self
+        grad = Grad()
+        self._recording = True
+        self._grad = grad
+        for param_id in self._param_dict.keys():
+            param_wrapper = self._param_dict[param_id]
+            callbacks = self._call_back_dict[param_id]

-                    def callback(
-                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                    ):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        gm._gradients[id(p)] = ret
+            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+                ret = grad
+                for cb in callbacks:
+                    ret = cb(param, ret)
+                gm._gradients[id(p)] = ret

-                    grad.wrt(param_wrapper, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                self._gradients = dict()
+            grad.wrt(param_wrapper, callback=callback)
+        grad.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_record()
+
+    record = __enter__

-        return recorder()
+    def _stop_record(self):
+        if self._grad is not None:
+            self._grad.__exit__(None, None, None)
+        self._recording = False
+        self._grad = None
+        self._gradients = dict()
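
Taken together, the grad_manager.py changes make GradManager its own context manager: __enter__ starts recording (re-entering while already recording is now a silent no-op rather than raising RuntimeError), __exit__ delegates to the new _stop_record() helper (also called from the finally block shown at the top of the hunk), and record is bound to the very same function object as __enter__, so gm.record() and a with-statement are interchangeable ways to start a recording session. Below is a minimal, self-contained sketch of that idiom; the Recorder class is purely illustrative and is not the MegEngine implementation.

class Recorder:
    """Toy class demonstrating the `record = __enter__` aliasing pattern."""

    def __init__(self):
        self._recording = False

    def __enter__(self):
        if self._recording:      # re-entry is a no-op, mirroring the diff
            return self
        self._recording = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop_record()

    record = __enter__           # same callable under two names

    def _stop_record(self):
        self._recording = False


r = Recorder()

# Imperative style: start and stop explicitly.
r.record()
assert r._recording
r._stop_record()
assert not r._recording

# Context-manager style: the with-statement drives __enter__/__exit__.
with r:
    assert r._recording
assert not r._recording

One reviewable consequence: callers that previously relied on record() raising when invoked twice will now silently keep using the existing recording session.
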

imperative/python/megengine/distributed/helper.py  (+4, -3)

@@ -10,6 +10,7 @@ import functools
 import multiprocessing as mp
 from collections import defaultdict
 from typing import Callable
+from weakref import WeakSet

 import numpy as np

@@ -23,7 +24,7 @@ from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed


-class FakeTensor(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"

@@ -77,7 +78,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._marked_gm = set()
+        self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()

@@ -107,7 +108,7 @@ class AllreduceCallback:
         gm._register_after_backward_callback(self._flush)
         self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = FakeTensor(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
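
Two things change in helper.py: the future placeholder class is renamed from FakeTensor to the more descriptive TensorFuture, and AllreduceCallback now tracks already-marked gradient managers in a weakref.WeakSet rather than a plain set, so this bookkeeping no longer keeps a GradManager alive after user code has dropped it. The snippet below is a self-contained illustration of that set/WeakSet difference; the Manager class is a placeholder, not MegEngine code.

import gc
from weakref import WeakSet


class Manager:
    pass


strong = set()
weak = WeakSet()

a, b = Manager(), Manager()
strong.add(a)   # the plain set owns a strong reference
weak.add(b)     # the WeakSet does not

del a, b
gc.collect()

print(len(strong))  # 1 - the object added to the plain set is still alive
print(len(weak))    # 0 - the WeakSet entry disappeared with its object
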




imperative/python/megengine/optimizer/optimizer.py  (+3, -1)

@@ -140,7 +140,7 @@ class Optimizer(metaclass=ABCMeta):
             params.append(param)
         return params

-    def step(self):
+    def step(self, clear_grad=False):
         r"""Performs a single optimization step.

         """
@@ -152,6 +152,8 @@
                 "Please use a list instead."
             )
             self._updates(group)
+        if clear_grad:
+            self.clear_grad()

     def clear_grad(self):
         r"""Clear the grad buffer.

