
fix(mge/imperative): remove backward from optimizer

GitOrigin-RevId: ad6ad444fa
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent commit e50fa074cb
2 changed files with 6 additions and 243 deletions
  1. +0 -120  imperative/python/megengine/optimizer/distributed_optimizer.py
  2. +6 -123  imperative/python/megengine/optimizer/optimizer.py
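
With backward() removed from Optimizer, gradient computation moves out of the optimizer and into the training loop. A minimal sketch of the resulting usage, assuming the GradManager API from megengine.autodiff takes over gradient recording; GradManager usage, model, dataloader and loss_fn are illustrative assumptions, while clear_grad() and step() come from this diff:

# Sketch only: assumes megengine.autodiff.GradManager now owns the gradient
# recording and backward pass that Optimizer no longer provides.
import megengine.optimizer as optim
from megengine.autodiff import GradManager

gm = GradManager()
gm.attach(model.parameters())              # `model` is assumed to exist
opt = optim.SGD(model.parameters(), lr=0.01)

for data, label in dataloader:             # `dataloader` is assumed
    with gm:                               # record the forward pass
        pred = model(data)
        loss = loss_fn(pred, label)        # any loss function
        gm.backward(loss)                  # replaces opt.backward(loss)
    opt.step()                             # parameter update, unchanged
    opt.clear_grad()                       # renamed from zero_grad()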

+0 -120  imperative/python/megengine/optimizer/distributed_optimizer.py

@@ -1,120 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable as Iter
from typing import Optional, Union

from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import WORLD, Group, all_reduce_sum, broadcast
from ..functional.utils import copy
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Parameter
from .optimizer import Optimizer
from .param_pack import get_pack_list, pack_allreduce_split


class DistributedOptimizer(Optimizer):
r"""Add Distributed Func for distributed training.

:param params: specifies what Tensors should be optimized.
:param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
:param reduce_method: use all_reduce_sum or all_reduce_mean to reduce gradients
:param bcast_period: broadcasts params every *bcast_period* iterations.
if it equals to 0, it will broadcast params only at the beginning. Default: 500
:param param_pack: whether to pack gradients to avoid small packages send/recv. Default: False
:param param_pack_thd: max size of packed gradients by bytes. Default: 10 * 1024 * 1024
"""

def __init__(
self,
params: Union[Iter[Parameter], dict],
defaults: dict,
reduce_method: Optional[str] = None,
dist_group: Optional[Group] = WORLD,
bcast_period: int = 0,
param_pack: bool = False,
param_pack_thd: int = 10 * 1024 * 1024,
):
if is_distributed():
assert reduce_method in ["sum", "mean"], "reduce_method must be specified"
defaults["orders"] = []
defaults["dist_group"] = dist_group
super().__init__(params, defaults)
self._bcast_period = bcast_period
self._param_pack = param_pack
self._param_pack_thd = param_pack_thd
self._reduce_method = reduce_method

self.add_save_load_state_ignore_keys(
{"grads", "orders", "pack_list", "shape_list", "dist_group"}
)

if is_distributed() and bcast_period != -1:
self.bcast_param()

def grad_callback(self, grad, i, group):
if is_distributed() and group["dist_group"] is not None:
dist_group = group["dist_group"]
if self._param_pack and "pack_list" in group:
for pack, shapes in zip(group["pack_list"], group["shape_list"]):
if i == pack[-1]:
pack_allreduce_split(group, pack, shapes, self._reduce_method)
else:
group["orders"].append(i)
group["grads"][i] = all_reduce_sum(
grad, dist_group, dist_group.comp_node
)
if self._reduce_method == "mean":
group["grads"][i] /= dist_group.size

def _gen_pack_list(self, group):
if "pack_list" not in group:
dist_group = group["dist_group"]
if dist_group.rank == 0:
pack_list, shape_list = get_pack_list(group, self._param_pack_thd)
get_client().set_pack_list(dist_group.key, (pack_list, shape_list))
else:
pack_list, shape_list = get_client().get_pack_list(dist_group.key)
group["pack_list"] = pack_list
group["shape_list"] = shape_list

def backward(self, loss: Tensor):
ret = super().backward(loss)
if is_distributed():
for group in self.param_groups:
if self._param_pack and group["dist_group"] is not None:
self._gen_pack_list(group)
return ret

def step(self):
if is_distributed():
for group in self.param_groups:
device = get_default_device()
for param in group["params"]:
if param.__wrapped__ not in self._grad_skip:
if param.grad.device != device:
param.grad = copy(param.grad, device)
if self._bcast_period > 0:
self._bcast_iter += 1
if self._bcast_iter == self._bcast_period:
self.bcast_param()
self._bcast_iter = 0
super().step()

def bcast_param(self):
device = get_default_device()
for group in self.param_groups:
for param in group["params"]:
dist_group = group["dist_group"]
new_param = broadcast(param, dist_group)
if new_param.device != device:
new_param = copy(new_param, device)
add_update(param, new_param, alpha=0)
param._reset(new_param)
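
Deleting DistributedOptimizer also removes the per-gradient all-reduce that grad_callback performed during backward. After this change that synchronization has to be wired up by whatever computes the gradients. A hedged sketch, under the assumption that GradManager.attach accepts per-tensor gradient callbacks; all_reduce_sum and get_world_size are the same primitives the removed file relied on, and the callback name is made up for illustration:

# Sketch only: the (param, grad) -> grad callback signature is an assumption
# about GradManager.attach(..., callbacks=...); the all-reduce mirrors what the
# removed DistributedOptimizer.grad_callback did for reduce_method="mean".
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.functional.distributed import all_reduce_sum

def _allreduce_mean(param, grad):
    # sum gradients across workers, then average ("mean" reduce method)
    return all_reduce_sum(grad) / dist.get_world_size()

gm = GradManager()
gm.attach(model.parameters(), callbacks=[_allreduce_mean])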

+6 -123  imperative/python/megengine/optimizer/optimizer.py

@@ -11,22 +11,13 @@ from collections import Iterable
from contextlib import contextmanager
from typing import Dict
from typing import Iterable as Iter
from typing import Set, Union
from typing import Union

import numpy as np

from ..core.autodiff.grad import Grad
from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import all_reduce_sum, broadcast
from ..functional.utils import copy
from ..logger import get_logger
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Buffer, Parameter

logger = get_logger(__name__)


class _RequiredParameter:
def __repr__(self):
@@ -43,10 +34,6 @@ class Optimizer(metaclass=ABCMeta):
:param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
"""

_recording = None
_grad = None
_gradients = None

def __init__( # pylint: disable=too-many-branches
self, params: Union[Iter[Parameter], dict], defaults: dict,
):
@@ -63,7 +50,6 @@ class Optimizer(metaclass=ABCMeta):
)

self.param_groups = [] # type: list
self.save_load_state_ignore_keys = set()

param_groups = list(params)
if len(param_groups) == 0:
@@ -154,100 +140,6 @@ class Optimizer(metaclass=ABCMeta):
params.append(param)
return params

def grad_callback(self, grad, i, group):
pass

def record(self):
@contextmanager
def recorder():
params = self._get_params()
grad = Grad()
gradients = [None] * len(params)
if self._recording:
raise RuntimeError("already recording!")
try:
self._recording = True
self._grad = grad
for group in self.param_groups:
group["grads"] = [None] * len(group["params"])
for i, param in enumerate(group["params"]):

def callback(tensor, grad, i=i, group=group, self=self):
group["grads"][i] = grad
self.grad_callback(grad, i, group)

grad.wrt(param, callback=callback)
with grad:
yield
finally:
self._recording = False
self._grad = None
for group in self.param_groups:
group["grads"] = []

return recorder()

def _calculate_gradients(self, loss: Tensor):
if not self._recording:
raise RuntimeError(
"no computation history. "
"did you forget record() or "
"call a method that clears the history?"
)
assert self._grad is not None

if len(loss.__wrapped__._extra_data) == 0: # in case loss depends on no tensor
self._grad = None
return

one = Tensor([1.0], dtype=loss.dtype, device=loss.device)
one = one.reshape(loss.shape)
try:
self._grad(loss, one)
finally:
self._grad = None

def minimize(self, loss: Tensor):
self.backward(loss)
self.step()

def backward(self, loss: Tensor):
"""Computes the back-propagation of the network given loss.

:param loss: The obtained loss tensor
"""
rst = []
self._calculate_gradients(loss)

# _grad_skip records the parameters which are not in the path of backward
self._grad_skip = set()
for group in self.param_groups:
# _grad_skip is consumed in optimizer.step()
# XXX: assumptions
# 1. Assume the same execution sequence for all GPUs in data parallel
# 2. If backward is called by multiple times to accumulate grad,
# it's also assumed same _grad_skip for all backward() calls
# Please change the code if any assumption is invalid
for param, grad in zip(group["params"], group["grads"]):
if grad is None:
self._grad_skip.add(param.__wrapped__)
continue
grad = Buffer(grad)
if getattr(param, "grad", None) is None:
param.grad = grad
else:
assert isinstance(param.grad, Buffer)
param.grad += grad
rst.append(param.grad)
if len(self._grad_skip) > 0:
get_logger(__name__).warning(
"{} parameters have no grad! "
"Make sure you pass the right parameters list".format(
len(self._grad_skip)
)
)
return rst

def step(self):
r"""Performs a single optimization step.

@@ -261,8 +153,8 @@ class Optimizer(metaclass=ABCMeta):
)
self._updates(group)

def zero_grad(self):
r"""Reset the grad to zeros.
def clear_grad(self):
r"""Clear the grad buffer.

"""
for param_group in self.param_groups:
@@ -270,9 +162,6 @@ class Optimizer(metaclass=ABCMeta):
if getattr(param, "grad", None) is not None:
param.grad = None

def add_save_load_state_ignore_keys(self, keys: Set[str]):
self.save_load_state_ignore_keys |= keys

def state_dict(self) -> Dict:
r"""Export the optimizer state.

@@ -293,11 +182,7 @@ class Optimizer(metaclass=ABCMeta):
state[param2id[param]] = st

for group in self.param_groups:
param_group = {
k: v
for k, v in group.items()
if k != "params" and k not in self.save_load_state_ignore_keys
}
param_group = {k: v for k, v in group.items() if k != "params"}
param_group["params"] = [param2id[param] for param in group["params"]]
param_groups.append(param_group)

@@ -329,14 +214,12 @@ class Optimizer(metaclass=ABCMeta):
if isinstance(v, Buffer):
self._state[p][k] = Buffer(v.numpy())

new_keys = set(group_new.keys()) - self.save_load_state_ignore_keys
saved_keys = set(group_saved.keys()) - self.save_load_state_ignore_keys
if new_keys != saved_keys:
if set(group_new.keys()) != set(group_saved.keys()):
raise ValueError(
"loaded state dict contains a parameter group that "
"doesn't match the keys of optimizer's group"
)
for key in saved_keys:
for key in group_new.keys():
if key != "params":
group_new[key] = group_saved[key]


