
refactor(mge/optimizer): refine gradmanager api, record = __enter__

GitOrigin-RevId: 5376177237
tags/v1.0.0-rc1
Megvii Engine Team · 4 years ago
commit c7acba41fc
3 changed files with 42 additions and 34 deletions
  1. imperative/python/megengine/autodiff/grad_manager.py   (+35 / -30)
  2. imperative/python/megengine/distributed/helper.py       (+4 / -3)
  3. imperative/python/megengine/optimizer/optimizer.py      (+3 / -1)

imperative/python/megengine/autodiff/grad_manager.py   (+35 / -30)

@@ -1,5 +1,6 @@
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Callable

 from ..core.autodiff.grad import Grad
 from ..tensor import tensor
@@ -21,7 +22,11 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()

-    def register(self, params, callbacks=[]):
+    def register(self, params, callbacks=None):
+        if callbacks is None:
+            callbacks = []
+        if isinstance(callbacks, Callable):
+            callbacks = [callbacks]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
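
Note on the hunk above: the default changes from a shared mutable callbacks=[] to callbacks=None, and a bare callable is now accepted and wrapped in a list. Below is a self-contained sketch of the equivalent argument handling; normalize_callbacks is an illustrative name, not part of this diff:

from typing import Callable

def normalize_callbacks(callbacks=None):
    # Mirrors the new register() handling: None -> fresh list (no shared
    # mutable default), a single callable -> one-element list.
    if callbacks is None:
        callbacks = []
    if isinstance(callbacks, Callable):
        callbacks = [callbacks]
    return callbacks

print(normalize_callbacks())              # []
print(normalize_callbacks(print))         # [<built-in function print>]
print(normalize_callbacks([abs, print]))  # [<built-in function abs>, <built-in function print>]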
@@ -62,37 +67,37 @@ class GradManager:
                 else:
                     param.grad += grad
         finally:
-            self._grad = None
-            self._gradients = dict()
+            self._stop_record()
             backwarding_grad_manager = cache

-    def record(self):
-        @contextmanager
-        def recorder():
-            grad = Grad()
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for param_id in self._param_dict.keys():
-                    param_wrapper = self._param_dict[param_id]
-                    callbacks = self._call_back_dict[param_id]
+    def __enter__(self):
+        if self._recording:
+            return self
+        grad = Grad()
+        self._recording = True
+        self._grad = grad
+        for param_id in self._param_dict.keys():
+            param_wrapper = self._param_dict[param_id]
+            callbacks = self._call_back_dict[param_id]

-                    def callback(
-                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                    ):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        gm._gradients[id(p)] = ret
+            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+                ret = grad
+                for cb in callbacks:
+                    ret = cb(param, ret)
+                gm._gradients[id(p)] = ret

-                    grad.wrt(param_wrapper, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                self._gradients = dict()
+            grad.wrt(param_wrapper, callback=callback)
+        grad.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_record()
+
+    record = __enter__

-        return recorder()
+    def _stop_record(self):
+        if self._grad is not None:
+            self._grad.__exit__(None, None, None)
+        self._recording = False
+        self._grad = None
+        self._gradients = dict()
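
With __enter__/__exit__ defined directly on GradManager and record aliased to __enter__, recording can be started either as a context manager or imperatively, and backward() now ends recording through _stop_record(). A hedged usage sketch follows; model, data, label and loss_fn are placeholders for a typical MegEngine training step and are not part of this commit:

from megengine.autodiff.grad_manager import GradManager

gm = GradManager()
gm.register(model.parameters())         # model: placeholder megengine Module

# Context-manager style: __exit__ stops recording even if an exception escapes.
with gm:
    loss = loss_fn(model(data), label)  # placeholder forward pass
    gm.backward(loss)                   # fills .grad on the registered parameters

# Imperative style: record() is literally the same method as __enter__;
# backward() stops recording in its finally block via _stop_record().
gm.record()
loss = loss_fn(model(data), label)
gm.backward(loss)

One behavioural change visible in the diff: __enter__ returns self immediately when recording is already active, so re-entering a recording GradManager is now a no-op instead of the RuntimeError("already recording!") raised by the old recorder.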

imperative/python/megengine/distributed/helper.py   (+4 / -3)

@@ -10,6 +10,7 @@ import functools
 import multiprocessing as mp
 from collections import defaultdict
 from typing import Callable
+from weakref import WeakSet

 import numpy as np

@@ -23,7 +24,7 @@ from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed


-class FakeTensor(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"

@@ -77,7 +78,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._marked_gm = set()
+        self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()

@@ -107,7 +108,7 @@ class AllreduceCallback:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = FakeTensor(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
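
Switching _marked_gm from set() to WeakSet() means the AllreduceCallback only remembers GradManager instances that are still referenced elsewhere; a plain set would keep a strong reference to every GradManager it ever marked and prevent it from being garbage collected. A small self-contained sketch of that difference (FakeGM is an illustrative stand-in, not MegEngine code):

import gc
from weakref import WeakSet

class FakeGM:          # illustrative stand-in for a GradManager
    pass

marked = WeakSet()
gm = FakeGM()
marked.add(gm)
print(len(marked))     # 1: the entry exists while gm is alive

del gm                 # drop the last strong reference
gc.collect()
print(len(marked))     # 0: the entry vanished; a plain set() would still hold it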



imperative/python/megengine/optimizer/optimizer.py   (+3 / -1)

@@ -140,7 +140,7 @@ class Optimizer(metaclass=ABCMeta):
                 params.append(param)
         return params

-    def step(self):
+    def step(self, clear_grad=False):
         r"""Performs a single optimization step.

         """
@@ -152,6 +152,8 @@ class Optimizer(metaclass=ABCMeta):
                     "Please use a list instead."
                 )
             self._updates(group)
+        if clear_grad:
+            self.clear_grad()

     def clear_grad(self):
         r"""Clear the grad buffer.
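
The new clear_grad flag folds the usual clear_grad() call into step(). A hedged usage sketch of a training step under that API; the SGD construction, model, loss_fn and data loader are placeholders and not part of this commit:

import megengine.optimizer as optim

opt = optim.SGD(model.parameters(), lr=0.01)   # model: placeholder Module

for data, label in loader:                     # loader: placeholder iterable
    with gm:                                   # GradManager from the change above
        loss = loss_fn(model(data), label)
        gm.backward(loss)
    opt.step(clear_grad=True)                  # same as opt.step(); opt.clear_grad()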

