From 55d8738def5f9c38241e44e718e06abe108b8390 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Tue, 14 Jun 2022 15:03:41 +0000
Subject: [PATCH] deepspeed driver init

---
 fastNLP/core/drivers/torch_driver/deepspeed.py | 165 +++++++++++++++++++++++++
 fastNLP/core/drivers/torch_driver/utils.py     |  87 ++++++++++++-
 fastNLP/envs/imports.py                        |   1 +
 3 files changed, 252 insertions(+), 1 deletion(-)
 create mode 100644 fastNLP/core/drivers/torch_driver/deepspeed.py

diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py
new file mode 100644
index 00000000..0e0637a1
--- /dev/null
+++ b/fastNLP/core/drivers/torch_driver/deepspeed.py
@@ -0,0 +1,165 @@
+from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List
+
+from .torch_driver import TorchDriver
+from .utils import _create_default_config, replace_sampler, replace_batch_sampler
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+    ReproduceBatchSampler, RandomSampler
+from fastNLP.core.log import logger
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.nn import DataParallel
+    from torch.utils.data import BatchSampler as TorchBatchSampler
+    from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler
+
+if _NEED_IMPORT_DEEPSPEED:
+    import deepspeed
+
+__all__ = [
+    "DeepSpeedDriver",
+]
+
+class DeepSpeedDriver(TorchDriver):
+    def __init__(self, model, fp16, strategy, **kwargs):
+        super(DeepSpeedDriver, self).__init__(model, fp16)
+
+        self.strategy = strategy
+
+    def setup(self):
+        # Map the requested strategy name to a default deepspeed config.
+        if self.strategy == "deepspeed":
+            self.config = _create_default_config(stage=2)
+        elif self.strategy == "deepspeed_stage_1":
+            self.config = _create_default_config(stage=1)
+        elif self.strategy == "deepspeed_stage_2":
+            self.config = _create_default_config(stage=2)
+        elif self.strategy == "deepspeed_stage_2_offload":
+            self.config = _create_default_config(stage=2, offload_optimizer=True)
+        elif self.strategy == "deepspeed_stage_3":
+            self.config = _create_default_config(stage=3)
+        elif self.strategy == "deepspeed_stage_3_offload":
+            self.config = _create_default_config(
+                stage=3,
+                offload_optimizer=True,
+                offload_parameters=True,
+            )
+        elif self.strategy == "deepspeed_stage_3_offload_nvme":
+            self.config = _create_default_config(
+                stage=3,
+                offload_optimizer=True,
+                offload_parameters=True,
+                offload_params_device="nvme",
+                offload_optimizer_device="nvme",
+            )
+        else:
+            raise ValueError(f"Unknown deepspeed strategy {self.strategy!r}.")
+
+        for i, optimizer in enumerate(self.optimizers):
+            # TODO: properly support multiple optimizers
+            engine, optimizer_ds, _, _ = deepspeed.initialize(
+                model=self.model,
+                optimizer=optimizer,
+                config=self.config
+            )
+            self._optimizers[i] = optimizer_ds
+        self.model = engine
+
+        self._set_deepspeed_activation_checkpointing()
+
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
+
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
+                                  reproducible: bool = False):
+        # TODO: distributed / reproducible sampling is not wired up yet; return the dataloader unchanged for now.
+        return dataloader
+        # If dist is a ReproducibleBatchSampler / ReproducibleSampler, we are being called from
+        # driver.load_checkpoint while resuming from a checkpoint;
+        if isinstance(dist, ReproducibleBatchSampler):
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleSampler):
+            return replace_sampler(dataloader, dist)
+
+        # If dist is a str or None, we are being called during trainer initialization;
+        args = self.get_dataloader_args(dataloader)
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+            batch_sampler = re_instantiate_sampler(args.batch_sampler)
+            return replace_batch_sampler(dataloader, batch_sampler)
+        elif isinstance(args.sampler, ReproducibleSampler):
+            sampler = re_instantiate_sampler(args.sampler)
+            return replace_sampler(dataloader, sampler)
+
+        if reproducible:
+            if type(args.batch_sampler) is TorchBatchSampler:
+                if type(args.sampler) is TorchRandomSampler:
+                    if getattr(args.sampler, '_num_samples', None) is None \
+                            and getattr(args.sampler, 'replacements', False) is False \
+                            and getattr(args.sampler, 'generator', None) is None:
+                        # The sampler is already random and not customized, so it is safe to replace it.
+                        sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                        logger.debug("Replace torch RandomSampler with fastNLP RandomSampler.")
+                        return replace_sampler(dataloader, sampler)
+                elif type(args.sampler) is TorchSequentialSampler:
+                    # Replace with a non-shuffling sampler.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                    logger.debug("Replace torch SequentialSampler with fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
+        else:
+            return dataloader
+
+    def unwrap_model(self):
+        r"""
+        :return: the original model, e.g. without the ``DataParallel`` wrapper;
+        """
+        if isinstance(self.model, deepspeed.DeepSpeedEngine):
+            return self.model.module
+        if isinstance(self.model, torch.nn.DataParallel) or \
+                isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
+            return self.model.module
+        else:
+            return self.model
+
+    @property
+    def data_device(self):
+        r"""
+        Note that in single-card mode ``driver.data_device`` is equivalent to ``driver.model_device``;
+        """
+        return self.model_device
+
+    def is_distributed(self):
+        r"""
+        :return: whether the current driver is a distributed driver; this initial ``DeepSpeedDriver`` returns ``False``;
+        """
+        return False
+
+    def _set_deepspeed_activation_checkpointing(self):
+        if self.config.get("activation_checkpointing"):
+            checkpoint_config = self.config["activation_checkpointing"]
+            deepspeed.checkpointing.configure(
+                mpu_=None,
+                partition_activations=checkpoint_config.get("partition_activations"),
+                contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
+                checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
+                profile=checkpoint_config.get("profile"),
+            )
\ No newline at end of file
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index f5a76a9e..da621b60 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -1,6 +1,6 @@
 import os
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 from enum import IntEnum
 import contextlib
 import random
@@ -292,3 +292,88 @@ def _check_dataloader_args_for_distributed(args, controller='Trainer'):
             f"``{substitution}``. The customized sampler should set for distributed running "
             f"before initializing ``{controller}`` , and then set the "
             f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+
+def _create_default_config(
+    zero_optimization: bool = True,
+    zero_allow_untested_optimizer: bool = True,
+    logging_batch_size_per_gpu: Union[str, int] = "auto",
+    partition_activations: bool = False,
+    cpu_checkpointing: bool = False,
+    contiguous_memory_optimization: bool = False,
+    synchronize_checkpoint_boundary: bool = False,
+    offload_optimizer: bool = False,
+    offload_parameters: bool = False,
+    offload_params_device: str = "cpu",
+    nvme_path: str = "/local_nvme",
+    params_buffer_count: int = 5,
+    params_buffer_size: int = 100_000_000,
+    max_in_cpu: int = 1_000_000_000,
+    offload_optimizer_device: str = "cpu",
+    optimizer_buffer_count: int = 4,
+    pin_memory: bool = False,
+    block_size: int = 1048576,
+    queue_depth: int = 8,
+    single_submit: bool = False,
+    overlap_events: bool = True,
+    thread_count: int = 1,
+    stage: int = 2,
+    contiguous_gradients: bool = True,
+    overlap_comm: bool = True,
+    allgather_partitions: bool = True,
+    reduce_scatter: bool = True,
+    allgather_bucket_size: int = 200_000_000,
+    reduce_bucket_size: int = 200_000_000,
+    sub_group_size: int = 1_000_000_000_000,
+) -> Dict:
+    cfg = {
+        "activation_checkpointing": {
+            "partition_activations": partition_activations,
+            "cpu_checkpointing": cpu_checkpointing,
+            "contiguous_memory_optimization": contiguous_memory_optimization,
+            "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
+        },
+        "aio": {
+            "block_size": block_size,
+            "queue_depth": queue_depth,
+            "single_submit": single_submit,
+            "overlap_events": overlap_events,
+            "thread_count": thread_count,
+        },
+    }
+    zero_kwargs = {
+        "stage": stage,
+        "contiguous_gradients": contiguous_gradients,
+        "overlap_comm": overlap_comm,
+        "allgather_partitions": allgather_partitions,
+        "reduce_scatter": reduce_scatter,
+        "allgather_bucket_size": allgather_bucket_size,
+        "reduce_bucket_size": reduce_bucket_size,
+        "sub_group_size": sub_group_size,
+    }
+    if zero_optimization:
+        zero_config = zero_kwargs
+
+        if offload_optimizer:
+            zero_config["offload_optimizer"] = {
+                "device": offload_optimizer_device,
+                "nvme_path": nvme_path,
+                "buffer_count": optimizer_buffer_count,
+                "pin_memory": pin_memory,
+            }
+        if offload_parameters:
+            zero_config["offload_param"] = {
+                "device": offload_params_device,
+                "nvme_path": nvme_path,
+                "buffer_count": params_buffer_count,
+                "buffer_size": params_buffer_size,
+                "max_in_cpu": max_in_cpu,
+                "pin_memory": pin_memory,
+            }
+        cfg = {
+            "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+            "zero_optimization": zero_config,
+            **cfg,
+        }
+    if logging_batch_size_per_gpu != "auto":
+        cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
+    return cfg
\ No newline at end of file
diff --git a/fastNLP/envs/imports.py b/fastNLP/envs/imports.py
index 77b642c3..52f49e59 100644
--- a/fastNLP/envs/imports.py
+++ b/fastNLP/envs/imports.py
@@ -22,5 +22,6 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and
 _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import
 _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import
 _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import
+_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'deepspeed' in need_import
 
 _TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")
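
Reviewer note (illustration only, not part of the patch): below is a minimal sketch of the ZeRO config that the ``_create_default_config`` helper added above builds for the ``deepspeed_stage_2_offload`` strategy used by ``DeepSpeedDriver.setup()``. It assumes the patch has been applied and fastNLP is importable; only names introduced in this patch are referenced.

# Illustration: inspect the default config that DeepSpeedDriver.setup() would
# pass to deepspeed.initialize() for strategy "deepspeed_stage_2_offload".
from fastNLP.core.drivers.torch_driver.utils import _create_default_config

cfg = _create_default_config(stage=2, offload_optimizer=True)

# Top-level sections produced by the helper with its defaults.
assert set(cfg) == {"zero_allow_untested_optimizer", "zero_optimization",
                    "activation_checkpointing", "aio"}

zero = cfg["zero_optimization"]
assert zero["stage"] == 2
# Optimizer state is offloaded to CPU using the defaults from the signature.
assert zero["offload_optimizer"] == {
    "device": "cpu",
    "nvme_path": "/local_nvme",
    "buffer_count": 4,
    "pin_memory": False,
}
# "offload_param" is only added when offload_parameters=True (stage-3 strategies).
assert "offload_param" not in zero

With the default ``logging_batch_size_per_gpu="auto"``, the helper leaves ``train_micro_batch_size_per_gpu`` out of the dict entirely, so batch-size settings must be supplied elsewhere before training.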