@@ -11,22 +11,13 @@ from collections import Iterable
from contextlib import contextmanager
from typing import Dict
from typing import Iterable as Iter
from typing import Set, Union
from typing import Union

import numpy as np

from ..core.autodiff.grad import Grad
from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import all_reduce_sum, broadcast
from ..functional.utils import copy
from ..logger import get_logger
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Buffer, Parameter

logger = get_logger(__name__)


class _RequiredParameter:
    def __repr__(self):
@@ -43,10 +34,6 @@ class Optimizer(metaclass=ABCMeta):
    :param defaults: a dict of default parameters for the Optimizer, such as learning rate or momentum.
    """

    _recording = None
    _grad = None
    _gradients = None

    def __init__(  # pylint: disable=too-many-branches
        self, params: Union[Iter[Parameter], dict], defaults: dict,
    ):
@@ -63,7 +50,6 @@ class Optimizer(metaclass=ABCMeta):
            )

        self.param_groups = []  # type: list
        self.save_load_state_ignore_keys = set()

        param_groups = list(params)
        if len(param_groups) == 0:
@@ -154,100 +140,6 @@ class Optimizer(metaclass=ABCMeta):
                params.append(param)
        return params

    def grad_callback(self, grad, i, group):
        pass
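
# A minimal sketch (editor's illustration, not part of this patch): grad_callback is a
# per-parameter hook invoked with each freshly computed gradient during recording.
# A subclass could override it, e.g. to log gradient norms. "SGD" is assumed to be an
# optimizer subclass provided elsewhere in this package; logger and np come from the
# module-level imports above.
class LoggingSGD(SGD):
    def grad_callback(self, grad, i, group):
        # grad is the gradient tensor for group["params"][i]
        logger.info("param %d grad norm: %.4f", i, np.linalg.norm(grad.numpy()))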

    def record(self):
        @contextmanager
        def recorder():
            params = self._get_params()
            grad = Grad()
            gradients = [None] * len(params)
            if self._recording:
                raise RuntimeError("already recording!")
            try:
                self._recording = True
                self._grad = grad
                for group in self.param_groups:
                    group["grads"] = [None] * len(group["params"])
                    for i, param in enumerate(group["params"]):

                        def callback(tensor, grad, i=i, group=group, self=self):
                            group["grads"][i] = grad
                            self.grad_callback(grad, i, group)

                        grad.wrt(param, callback=callback)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None
                for group in self.param_groups:
                    group["grads"] = []

        return recorder()
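
# A minimal usage sketch (editor's illustration, not part of this patch): record()
# returns a context manager that hooks a gradient callback onto every managed
# parameter; backward(loss) then fills or accumulates param.grad, and step() applies
# the update. "SGD", "net", "loss_fn", "data", and "label" are assumptions.
opt = SGD(net.parameters(), lr=0.01)
with opt.record():
    loss = loss_fn(net(data), label)
    opt.backward(loss)   # inside the block, opt.minimize(loss) = backward() + step()
opt.step()
opt.clear_grad()         # or zero_grad(); see the rename further down in this diff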

    def _calculate_gradients(self, loss: Tensor):
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget to call record() or "
                "did you call a method that clears the history?"
            )
        assert self._grad is not None

        if len(loss.__wrapped__._extra_data) == 0:  # loss does not depend on any tensor
            self._grad = None
            return

        one = Tensor([1.0], dtype=loss.dtype, device=loss.device)
        one = one.reshape(loss.shape)
        try:
            self._grad(loss, one)
        finally:
            self._grad = None

    def minimize(self, loss: Tensor):
        self.backward(loss)
        self.step()

    def backward(self, loss: Tensor):
        """Computes the back-propagation of the network given the loss.

        :param loss: the obtained loss tensor.
        """
        rst = []
        self._calculate_gradients(loss)

        # _grad_skip records the parameters that are not on the backward path
        self._grad_skip = set()
        for group in self.param_groups:
            # _grad_skip is consumed in optimizer.step()
            # XXX: assumptions
            # 1. Assume the same execution sequence for all GPUs in data parallel
            # 2. If backward is called multiple times to accumulate grads,
            #    the same _grad_skip is assumed for all backward() calls
            # Please change the code if any assumption is invalid
            for param, grad in zip(group["params"], group["grads"]):
                if grad is None:
                    self._grad_skip.add(param.__wrapped__)
                    continue
                grad = Buffer(grad)
                if getattr(param, "grad", None) is None:
                    param.grad = grad
                else:
                    assert isinstance(param.grad, Buffer)
                    param.grad += grad
                rst.append(param.grad)
        if len(self._grad_skip) > 0:
            get_logger(__name__).warning(
                "{} parameters have no grad! "
                "Make sure you pass the correct parameter list".format(
                    len(self._grad_skip)
                )
            )
        return rst
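
# A minimal sketch (editor's illustration, not part of this patch): because backward()
# adds into param.grad when it already holds a Buffer, calling it once per micro-batch
# before a single step() accumulates gradients, matching the assumption noted in the
# comments above. "opt", "net", "loss_fn", and "micro_batches" are assumptions carried
# over from the earlier sketch.
for data, label in micro_batches:
    with opt.record():
        opt.backward(loss_fn(net(data), label))
opt.step()        # one update using the accumulated gradients
opt.clear_grad()  # reset param.grad before the next accumulation window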

    def step(self):
        r"""Performs a single optimization step.

@@ -261,8 +153,8 @@ class Optimizer(metaclass=ABCMeta):
                )
            self._updates(group)

    def zero_grad(self):
        r"""Reset the grad to zeros.
    def clear_grad(self):
        r"""Clear the grad buffer.

        """
        for param_group in self.param_groups:
@@ -270,9 +162,6 @@ class Optimizer(metaclass=ABCMeta):
                if getattr(param, "grad", None) is not None:
                    param.grad = None
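
# Editor's note with a minimal sketch (not part of this patch): the paired "def" lines
# above read as a rename of zero_grad ("Reset the grad to zeros.") to clear_grad
# ("Clear the grad buffer."); the body shown sets param.grad to None. Either way it is
# typically called once per iteration, before recording the next forward pass; "opt"
# is an assumption from the earlier sketches.
opt.clear_grad()  # or opt.zero_grad(), depending on which side of the diff is in use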

    def add_save_load_state_ignore_keys(self, keys: Set[str]):
        self.save_load_state_ignore_keys |= keys

    def state_dict(self) -> Dict:
        r"""Export the optimizer state.

@@ -293,11 +182,7 @@ class Optimizer(metaclass=ABCMeta):
            state[param2id[param]] = st

        for group in self.param_groups:
            param_group = {
                k: v
                for k, v in group.items()
                if k != "params" and k not in self.save_load_state_ignore_keys
            }
            param_group = {k: v for k, v in group.items() if k != "params"}
            param_group["params"] = [param2id[param] for param in group["params"]]
            param_groups.append(param_group)

@@ -329,14 +214,12 @@ class Optimizer(metaclass=ABCMeta):
                        if isinstance(v, Buffer):
                            self._state[p][k] = Buffer(v.numpy())

            new_keys = set(group_new.keys()) - self.save_load_state_ignore_keys
            saved_keys = set(group_saved.keys()) - self.save_load_state_ignore_keys
            if new_keys != saved_keys:
            if set(group_new.keys()) != set(group_saved.keys()):
                raise ValueError(
                    "loaded state dict contains a parameter group that "
                    "doesn't match the keys of the optimizer's group"
                )
            for key in saved_keys:
            for key in group_new.keys():
                if key != "params":
                    group_new[key] = group_saved[key]
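
# A minimal sketch (editor's illustration, not part of this patch): state_dict() maps
# parameters to integer ids, so the exported object can be serialized, and
# load_state_dict() checks that each saved group's keys match the optimizer's groups
# before copying the hyperparameters back. On the side of the diff that still has
# save_load_state_ignore_keys, keys registered via add_save_load_state_ignore_keys()
# are skipped in that export, comparison, and copy. Pickle usage, the file name, and
# "opt" are assumptions.
import pickle

with open("optimizer_state.pkl", "wb") as f:
    pickle.dump(opt.state_dict(), f)

with open("optimizer_state.pkl", "rb") as f:
    opt.load_state_dict(pickle.load(f))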