- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from typing import Iterable as Iter
- from typing import Optional, Union
-
- from ..device import get_default_device
- from ..distributed.group import get_client, is_distributed
- from ..functional import add_update
- from ..functional.distributed import WORLD, Group, all_reduce_sum, broadcast
- from ..functional.utils import copy
- from ..tensor import Tensor, TensorDict
- from ..tensor_nn import Parameter
- from .optimizer import Optimizer
- from .param_pack import get_pack_list, pack_allreduce_split
-
-
- class DistributedOptimizer(Optimizer):
- r"""Add Distributed Func for distributed training.
-
- :param params: specifies what Tensors should be optimized.
- :param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
- :param reduce_method: how to reduce gradients across workers, either "sum" or "mean".
- :param dist_group: the process group in which gradients are reduced and
- parameters are broadcast. Default: WORLD
- :param bcast_period: broadcast params every *bcast_period* iterations.
- If it equals 0, params are broadcast only at the beginning; if it equals -1,
- no automatic broadcast is performed. Default: 0
- :param param_pack: whether to pack gradients together to avoid sending/receiving
- many small tensors. Default: False
- :param param_pack_thd: max size of a gradient pack in bytes. Default: 10 * 1024 * 1024
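- 
- Example (a minimal sketch, not from the official docs; it assumes a concrete
- subclass such as ``SGD`` that forwards these keyword arguments to this base
- class, and an existing module ``net`` plus an input batch ``data``):
- 
- .. code-block:: python
- 
-     import megengine.optimizer as optim
- 
-     opt = optim.SGD(net.parameters(), lr=0.05, reduce_method="mean")
-     opt.zero_grad()
-     loss = net(data)          # assume the module returns a scalar loss
-     opt.backward(loss)        # backward pass + gradient all-reduce across workers
-     opt.step()                # every worker applies the same reduced update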
- """
-
- def __init__(
- self,
- params: Union[Iter[Parameter], dict],
- defaults: dict,
- reduce_method: Optional[str] = None,
- dist_group: Optional[Group] = WORLD,
- bcast_period: int = 0,
- param_pack: bool = False,
- param_pack_thd: int = 10 * 1024 * 1024,
- ):
- if is_distributed():
- assert reduce_method in ["sum", "mean"], 'reduce_method must be "sum" or "mean" for distributed training'
- defaults["orders"] = []
- defaults["dist_group"] = dist_group
- super().__init__(params, defaults)
- self._bcast_period = bcast_period
- self._bcast_iter = 0  # iterations elapsed since the last parameter broadcast
- self._param_pack = param_pack
- self._param_pack_thd = param_pack_thd
- self._reduce_method = reduce_method
-
- self.add_save_load_state_ignore_keys(
- {"grads", "orders", "pack_list", "shape_list", "dist_group"}
- )
-
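- # Broadcast parameters once at construction so all workers start from the
- # same values, unless broadcasting is disabled with bcast_period == -1.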
- if is_distributed() and bcast_period != -1:
- self.bcast_param()
-
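- # Called once per parameter gradient: all-reduce it across dist_group, either
- # immediately or, when param packing is enabled, as part of a fused pack.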
- def grad_callback(self, grad, i, group):
- if is_distributed() and group["dist_group"] is not None:
- dist_group = group["dist_group"]
- if self._param_pack and "pack_list" in group:
- for pack, shapes in zip(group["pack_list"], group["shape_list"]):
- if i == pack[-1]:
- pack_allreduce_split(group, pack, shapes, self._reduce_method)
- else:
- group["orders"].append(i)
- group["grads"][i] = all_reduce_sum(
- grad, dist_group, dist_group.comp_node
- )
- if self._reduce_method == "mean":
- group["grads"][i] /= dist_group.size
-
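- # Build the packing plan lazily: rank 0 decides how gradients are grouped into
- # packs (bounded by param_pack_thd) and publishes the plan via the distributed
- # client so every rank splits the fused tensors identically.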
- def _gen_pack_list(self, group):
- if "pack_list" not in group:
- dist_group = group["dist_group"]
- if dist_group.rank == 0:
- pack_list, shape_list = get_pack_list(group, self._param_pack_thd)
- get_client().set_pack_list(dist_group.key, (pack_list, shape_list))
- else:
- pack_list, shape_list = get_client().get_pack_list(dist_group.key)
- group["pack_list"] = pack_list
- group["shape_list"] = shape_list
-
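- # Run the regular backward pass, then generate the packing plan for every
- # parameter group that has a distributed group (only when param packing is enabled).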
- def backward(self, loss: Tensor):
- ret = super().backward(loss)
- if is_distributed():
- for group in self.param_groups:
- if self._param_pack and group["dist_group"] is not None:
- self._gen_pack_list(group)
- return ret
-
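- # Before the actual update, make sure every gradient lives on the default
- # device, and re-broadcast parameters every bcast_period iterations.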
- def step(self):
- if is_distributed():
- for group in self.param_groups:
- device = get_default_device()
- for param in group["params"]:
- if param.__wrapped__ not in self._grad_skip:
- if param.grad.device != device:
- param.grad = copy(param.grad, device)
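- # Re-broadcast parameters once every bcast_period iterations.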
- if self._bcast_period > 0:
- self._bcast_iter += 1
- if self._bcast_iter == self._bcast_period:
- self.bcast_param()
- self._bcast_iter = 0
- super().step()
-
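- # Broadcast the current parameter values within each distributed group so
- # that every worker holds identical copies.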
- def bcast_param(self):
- device = get_default_device()
- for group in self.param_groups:
- for param in group["params"]:
- dist_group = group["dist_group"]
- new_param = broadcast(param, dist_group)
- if new_param.device != device:
- new_param = copy(new_param, device)
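- # Overwrite the parameter in place with the broadcast value (alpha=0 drops
- # the old contents) and rebind the wrapper to the new tensor.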
- add_update(param, new_param, alpha=0)
- param._reset(new_param)