
deepspeed driver init

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 55d8738def
3 changed files with 252 additions and 1 deletion
  1. fastNLP/core/drivers/torch_driver/deepspeed.py  (+165, -0)
  2. fastNLP/core/drivers/torch_driver/utils.py  (+86, -1)
  3. fastNLP/envs/imports.py  (+1, -0)

fastNLP/core/drivers/torch_driver/deepspeed.py  (+165, -0)

@@ -0,0 +1,165 @@
from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List
from .torch_driver import TorchDriver
from .utils import _create_default_config, replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
    ReproduceBatchSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

if _NEED_IMPORT_TORCH:
    import torch
    from torch.nn import DataParallel
    from torch.utils.data import RandomSampler as TorchRandomSampler
    from torch.utils.data import SequentialSampler as TorchSequentialSampler
    from torch.utils.data import BatchSampler as TorchBatchSampler
if _NEED_IMPORT_DEEPSPEED:
    import deepspeed

__all__ = [
    "DeepSpeedDriver",
]

class DeepSpeedDriver(TorchDriver):
    def __init__(self, model, fp16, strategy, **kwargs):
        super(DeepSpeedDriver, self).__init__(model, fp16)

        self.strategy = strategy

    def setup(self):
        # Map the requested strategy string onto a default DeepSpeed (ZeRO) config.
        if self.strategy == "deepspeed":
            self.config = _create_default_config(stage=2)
        elif self.strategy == "deepspeed_stage_1":
            self.config = _create_default_config(stage=1)
        elif self.strategy == "deepspeed_stage_2":
            self.config = _create_default_config(stage=2)
        elif self.strategy == "deepspeed_stage_2_offload":
            self.config = _create_default_config(stage=2, offload_optimizer=True)
        elif self.strategy == "deepspeed_stage_3":
            self.config = _create_default_config(stage=3)
        elif self.strategy == "deepspeed_stage_3_offload":
            self.config = _create_default_config(
                stage=3,
                offload_optimizer=True,
                offload_parameters=True,
            )
        elif self.strategy == "deepspeed_stage_3_offload_nvme":
            self.config = _create_default_config(
                stage=3,
                offload_optimizer=True,
                offload_parameters=True,
                offload_params_device="nvme",
                offload_optimizer_device="nvme",
            )
        else:
            raise ValueError(f"Unknown deepspeed strategy {self.strategy!r}.")
        for i, optimizer in enumerate(self.optimizers):
            # TODO: handle multiple optimizers properly
            engine, optimizer_ds, _, _ = deepspeed.initialize(
                model=self.model,
                optimizer=optimizer,
                config=self.config
            )
            self._optimizers[i] = optimizer_ds
            self.model = engine

        self._set_deepspeed_activation_checkpointing()

    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return fn(batch)

    def get_model_call_fn(self, fn: str) -> Tuple:
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
            return self.model, self.model.forward
        else:
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

    def set_dist_repro_dataloader(self, dataloader,
                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                  reproducible: bool = False):
        # For now the deepspeed driver returns the dataloader unchanged; the
        # reproducible-sampler handling below is kept but not yet reached.
        return dataloader
        # If dist is a ReproducibleBatchSampler or ReproducibleSampler, this call comes
        # from driver.load_checkpoint when resuming from a checkpoint;
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)
        elif isinstance(dist, ReproducibleSampler):
            return replace_sampler(dataloader, dist)

        # If dist is a str or None, this call comes from trainer initialization;
        args = self.get_dataloader_args(dataloader)
        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
            batch_sampler = re_instantiate_sampler(args.batch_sampler)
            return replace_batch_sampler(dataloader, batch_sampler)
        elif isinstance(args.sampler, ReproducibleSampler):
            sampler = re_instantiate_sampler(args.sampler)
            return replace_sampler(dataloader, sampler)

        if reproducible:
            if type(args.batch_sampler) is TorchBatchSampler:
                if type(args.sampler) is TorchRandomSampler:
                    if getattr(args.sampler, '_num_samples', None) is None \
                            and getattr(args.sampler, 'replacement', False) is False \
                            and getattr(args.sampler, 'generator', None) is None:
                        # The sampler is already plain random and has no customization,
                        # so it can be replaced directly.
                        sampler = RandomSampler(args.sampler.data_source, shuffle=True)
                        logger.debug("Replace torch RandomSampler with fastNLP RandomSampler.")
                        return replace_sampler(dataloader, sampler)
                elif type(args.sampler) is TorchSequentialSampler:
                    # Replace with a non-shuffling sampler.
                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
                    logger.debug("Replace torch SequentialSampler with fastNLP RandomSampler.")
                    return replace_sampler(dataloader, sampler)
            batch_sampler = ReproduceBatchSampler(
                batch_sampler=args.batch_sampler,
                batch_size=args.batch_size,
                drop_last=args.drop_last
            )
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
            return dataloader

    def unwrap_model(self):
        r"""
        :return: the underlying model, i.e. without the ``DeepSpeedEngine`` / ``DataParallel`` wrapper;
        """
        if isinstance(self.model, deepspeed.DeepSpeedEngine):
            return self.model.module
        if isinstance(self.model, torch.nn.DataParallel) or \
                isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            return self.model.module
        else:
            return self.model

    @property
    def data_device(self):
        r"""
        Note that in single-card mode ``driver.data_device`` is equivalent to ``driver.model_device``;
        """
        return self.model_device

    def is_distributed(self):
        r"""
        :return: whether the driver currently in use is a distributed driver; here this simply returns ``False``;
        """
        return False

    def _set_deepspeed_activation_checkpointing(self):
        if self.config.get("activation_checkpointing"):
            checkpoint_config = self.config["activation_checkpointing"]
            deepspeed.checkpointing.configure(
                mpu_=None,
                partition_activations=checkpoint_config.get("partition_activations"),
                contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
                checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
                profile=checkpoint_config.get("profile"),
            )

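For orientation, here is a minimal usage sketch of the new driver. The toy model, the optimizer, and the set_optimizers call are illustrative assumptions (in practice the Trainer constructs and configures the driver), and deepspeed.initialize still needs a GPU plus a properly initialized distributed environment:

import torch
from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver

# Hypothetical usage sketch; not part of this commit.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

driver = DeepSpeedDriver(model, fp16=False, strategy="deepspeed_stage_2")
driver.set_optimizers(optimizer)  # assumed helper inherited from the Driver base class
driver.setup()                    # builds the ZeRO config and wraps the model in a DeepSpeedEngine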
fastNLP/core/drivers/torch_driver/utils.py  (+86, -1)

@@ -1,6 +1,6 @@
import os

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
from enum import IntEnum
import contextlib
import random
@@ -292,3 +292,88 @@ def _check_dataloader_args_for_distributed(args, controller='Trainer'):
                        f"``{substitution}``. The customized sampler should set for distributed running "
                        f"before initializing ``{controller}`` , and then set the "
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")

def _create_default_config(
        zero_optimization: bool = True,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: Union[str, int] = "auto",
        partition_activations: bool = False,
        cpu_checkpointing: bool = False,
        contiguous_memory_optimization: bool = False,
        synchronize_checkpoint_boundary: bool = False,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        pin_memory: bool = False,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        stage: int = 2,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        sub_group_size: int = 1_000_000_000_000,
) -> Dict:
    cfg = {
        "activation_checkpointing": {
            "partition_activations": partition_activations,
            "cpu_checkpointing": cpu_checkpointing,
            "contiguous_memory_optimization": contiguous_memory_optimization,
            "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
        },
        "aio": {
            "block_size": block_size,
            "queue_depth": queue_depth,
            "single_submit": single_submit,
            "overlap_events": overlap_events,
            "thread_count": thread_count,
        },
    }
    zero_kwargs = {
        "stage": stage,
        "contiguous_gradients": contiguous_gradients,
        "overlap_comm": overlap_comm,
        "allgather_partitions": allgather_partitions,
        "reduce_scatter": reduce_scatter,
        "allgather_bucket_size": allgather_bucket_size,
        "reduce_bucket_size": reduce_bucket_size,
        "sub_group_size": sub_group_size,
    }
    if zero_optimization:
        zero_config = zero_kwargs

        if offload_optimizer:
            zero_config["offload_optimizer"] = {
                "device": offload_optimizer_device,
                "nvme_path": nvme_path,
                "buffer_count": optimizer_buffer_count,
                "pin_memory": pin_memory,
            }
        if offload_parameters:
            zero_config["offload_param"] = {
                "device": offload_params_device,
                "nvme_path": nvme_path,
                "buffer_count": params_buffer_count,
                "buffer_size": params_buffer_size,
                "max_in_cpu": max_in_cpu,
                "pin_memory": pin_memory,
            }
        cfg = {
            "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
            "zero_optimization": zero_config,
            **cfg,
        }
    if logging_batch_size_per_gpu != "auto":
        cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
    return cfg

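For reference, a quick sketch of what the helper returns for the deepspeed_stage_2_offload case; the structure shown in the comments follows directly from the defaults above:

from fastNLP.core.drivers.torch_driver.utils import _create_default_config

cfg = _create_default_config(stage=2, offload_optimizer=True)
# cfg also contains "zero_allow_untested_optimizer": True plus the
# "activation_checkpointing" and "aio" sections; cfg["zero_optimization"] looks like:
# {
#     "stage": 2,
#     "contiguous_gradients": True,
#     "overlap_comm": True,
#     "allgather_partitions": True,
#     "reduce_scatter": True,
#     "allgather_bucket_size": 200_000_000,
#     "reduce_bucket_size": 200_000_000,
#     "sub_group_size": 1_000_000_000_000,
#     "offload_optimizer": {
#         "device": "cpu",
#         "nvme_path": "/local_nvme",
#         "buffer_count": 4,
#         "pin_memory": False,
#     },
# }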
fastNLP/envs/imports.py  (+1, -0)

@@ -22,5 +22,6 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import
_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'deepspeed' in need_import

_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")

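The new flag is consumed like the other backend flags, e.g. by the guarded import at the top of the new deepspeed.py; as a minimal illustration:

from fastNLP.envs.imports import _NEED_IMPORT_DEEPSPEED

# deepspeed is only imported when the package is installed and the
# 'deepspeed' backend was requested via need_import.
if _NEED_IMPORT_DEEPSPEED:
    import deepspeed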