
distributed_optimizer.py 5.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable as Iter
from typing import Optional, Union

from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import WORLD, Group, all_reduce_sum, broadcast
from ..functional.utils import copy
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Parameter
from .optimizer import Optimizer
from .param_pack import get_pack_list, pack_allreduce_split


class DistributedOptimizer(Optimizer):
    r"""Add distributed functionality to an optimizer for distributed training.

    :param params: specifies which tensors should be optimized.
    :param defaults: a dict of default hyperparameters of the optimizer,
        such as learning rate or momentum.
    :param reduce_method: use all_reduce_sum or all_reduce_mean to reduce gradients.
    :param bcast_period: broadcast parameters every *bcast_period* iterations.
        If it equals 0, parameters are broadcast only at the beginning. Default: 0
    :param param_pack: whether to pack gradients to avoid sending/receiving many
        small packages. Default: False
    :param param_pack_thd: max size of packed gradients, in bytes. Default: 10 * 1024 * 1024
    """

    def __init__(
        self,
        params: Union[Iter[Parameter], dict],
        defaults: dict,
        reduce_method: Optional[str] = None,
        dist_group: Optional[Group] = WORLD,
        bcast_period: int = 0,
        param_pack: bool = False,
        param_pack_thd: int = 10 * 1024 * 1024,
    ):
        if is_distributed():
            assert reduce_method in ["sum", "mean"], "reduce_method must be specified"
        defaults["orders"] = []
        defaults["dist_group"] = dist_group
        super().__init__(params, defaults)
        self._bcast_iter = 0  # counts steps since the last parameter broadcast
        self._bcast_period = bcast_period
        self._param_pack = param_pack
        self._param_pack_thd = param_pack_thd
        self._reduce_method = reduce_method
        self.add_save_load_state_ignore_keys(
            {"grads", "orders", "pack_list", "shape_list", "dist_group"}
        )
        if is_distributed() and bcast_period != -1:
            self.bcast_param()

    def grad_callback(self, grad, i, group):
        # Called as each gradient is produced: either fuse it into a packed
        # all-reduce (param_pack path) or all-reduce it individually.
        if is_distributed() and group["dist_group"] is not None:
            dist_group = group["dist_group"]
            if self._param_pack and "pack_list" in group:
                for pack, shapes in zip(group["pack_list"], group["shape_list"]):
                    if i == pack[-1]:
                        pack_allreduce_split(group, pack, shapes, self._reduce_method)
            else:
                group["orders"].append(i)
                group["grads"][i] = all_reduce_sum(
                    grad, dist_group, dist_group.comp_node
                )
                if self._reduce_method == "mean":
                    group["grads"][i] /= dist_group.size

    def _gen_pack_list(self, group):
        # Rank 0 decides how gradients are grouped into packs and shares the
        # plan with the other ranks through the distributed client.
        if "pack_list" not in group:
            dist_group = group["dist_group"]
            if dist_group.rank == 0:
                pack_list, shape_list = get_pack_list(group, self._param_pack_thd)
                get_client().set_pack_list(dist_group.key, (pack_list, shape_list))
            else:
                pack_list, shape_list = get_client().get_pack_list(dist_group.key)
            group["pack_list"] = pack_list
            group["shape_list"] = shape_list

    def backward(self, loss: Tensor):
        ret = super().backward(loss)
        if is_distributed():
            for group in self.param_groups:
                if self._param_pack and group["dist_group"] is not None:
                    self._gen_pack_list(group)
        return ret

    def step(self):
        # Move any gradient that ended up on another device back to the default
        # device, optionally re-broadcast parameters, then run the actual update.
        if is_distributed():
            for group in self.param_groups:
                device = get_default_device()
                for param in group["params"]:
                    if param.__wrapped__ not in self._grad_skip:
                        if param.grad.device != device:
                            param.grad = copy(param.grad, device)
            if self._bcast_period > 0:
                self._bcast_iter += 1
                if self._bcast_iter == self._bcast_period:
                    self.bcast_param()
                    self._bcast_iter = 0
        super().step()

    def bcast_param(self):
        # Broadcast parameters from the group root so that all workers stay in sync.
        device = get_default_device()
        for group in self.param_groups:
            for param in group["params"]:
                dist_group = group["dist_group"]
                new_param = broadcast(param, dist_group)
                if new_param.device != device:
                    new_param = copy(new_param, device)
                add_update(param, new_param, alpha=0)
                param._reset(new_param)
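
For context, here is a minimal usage sketch. It is not part of the file above: ToySGD, net, data and loss are illustrative assumptions, the import path is inferred from the file's relative imports (megengine/optimizer/distributed_optimizer.py), and it assumes the 0.x-era base Optimizer dispatches parameter updates through a _updates(group) hook. The distributed keyword arguments are simply forwarded to DistributedOptimizer.__init__, so the all-reduce in grad_callback and the packed path are used automatically once megengine.distributed is initialized.

from megengine.functional import add_update
from megengine.optimizer.distributed_optimizer import DistributedOptimizer


class ToySGD(DistributedOptimizer):
    """Plain SGD update on top of the distributed base class (illustrative only)."""

    def __init__(self, params, lr, **dist_kwargs):
        # lr goes into `defaults`; reduce_method / bcast_period / param_pack /
        # param_pack_thd are forwarded to DistributedOptimizer.__init__.
        super().__init__(params, dict(lr=lr), **dist_kwargs)

    def _updates(self, group):
        # Assumed hook: the 0.x-era Optimizer.step() calls _updates per param group.
        lr = group["lr"]
        for param in group["params"]:
            if param.grad is None:
                continue
            add_update(param, param.grad, alpha=1.0, beta=-lr)  # w <- w - lr * grad


# Typical call sequence inside one distributed training iteration
# (net, data and loss are assumed to exist):
# opt = ToySGD(net.parameters(), lr=0.01, reduce_method="mean", param_pack=True)
# opt.zero_grad()
# loss = net(data)
# opt.backward(loss)   # gradients are all-reduced via grad_callback / the pack path
# opt.step()           # grads moved to the default device, params re-broadcast if due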

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
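
As a quick sanity check that the bundled CUDA runtime can actually see a GPU, something like the following can be used (a sketch assuming a standard MegEngine install; the exact location of these helpers can vary between versions):

import megengine as mge

# Report whether CUDA devices are visible to this MegEngine installation.
if mge.is_cuda_available():
    print("CUDA is available; GPU count:", mge.get_device_count("gpu"))
else:
    print("No usable GPU detected; programs will run on CPU.")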