
deepspeed: basic functionality

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 1a2eb93ab4
2 changed files with 194 additions and 116 deletions
  1. +193 -115  fastNLP/core/drivers/torch_driver/deepspeed.py
  2. +1 -1      fastNLP/envs/imports.py

+193 -115  fastNLP/core/drivers/torch_driver/deepspeed.py

@@ -1,17 +1,16 @@
import os

from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List
from .torch_driver import TorchDriver
from .utils import _create_default_config
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
ReproduceBatchSampler
from .ddp import TorchDDPDriver
from .utils import _create_default_config, _DDPWrappingModel
from fastNLP.core.log import logger
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

if _NEED_IMPORT_TORCH:
import pytorch_lightning
import torch
from torch.nn import DataParallel
import torch.distributed as dist
if _NEED_IMPORT_DEEPSPEED:
import deepspeed
@@ -20,13 +19,160 @@ __all__ = [
"DeepSpeedDriver",
]

class DeepSpeedDriver(TorchDriver):
def __init__(self, model, fp16, strategy, **kwargs):
super(DeepSpeedDriver, self).__init__(model, fp16)

class DeepSpeedDriver(TorchDDPDriver):
# TODO fp16 load_config
def __init__(
self,
model,
parallel_device: Union[List["torch.device"], "torch.device"],
is_pull_by_torch_run = False,
fp16: bool = False,
strategy= "deepspeed",
**kwargs
):
assert _NEED_IMPORT_DEEPSPEED, "deepspeed is not imported."
assert not dist.is_initialized(), "DeepSpeedDriver does not support a distributed environment initialized by the user."
TorchDriver.__init__(self, model=model, fp16=False, **kwargs)
self.fp16 = fp16

# If the user initializes DDP themselves outside, the processes must have been launched via python -m torch.distributed.launch;
self.is_pull_by_torch_run = is_pull_by_torch_run
self.parallel_device = parallel_device
if not is_pull_by_torch_run and parallel_device is None:
raise ValueError(
"Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused "
"when your value of parameter `device` is `None` in your `Trainer` instance.")

# Note: the logic in initialize_torch_driver is that, if is_pull_by_torch_run is True, parallel_device is directly set to the GPU of the current process;
if is_pull_by_torch_run:
self.model_device = parallel_device
else:
# our model_device must be a single torch.device, not a list;
self.model_device = parallel_device[self.local_rank]

# initializing outside the driver is not allowed for now
self.outside_ddp = False
self._data_device = kwargs.get("data_device", None)
if isinstance(self._data_device, int):
if self._data_device < 0:
raise ValueError("Parameter `data_device` can not be smaller than 0.")
_could_use_device_num = torch.cuda.device_count()
if self._data_device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
self._data_device = torch.device(f"cuda:{self._data_device}")
elif isinstance(self._data_device, str):
self._data_device = torch.device(self._data_device)
elif self._data_device is not None and not isinstance(self._data_device, torch.device):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

self._master_port = None
# world_size is the total number of GPUs across all processes;
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
self.global_rank = 0

self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be of type `str`."
if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

self._has_setup = False  # this flag exists because the evaluator also performs setup, which is unnecessary and should not happen;
self._has_ddpwrapped = False  # whether the model passed in has already been wrapped by _DDPWrappingModel;
self.strategy = strategy

self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})

def setup(self):
r"""
准备分布式环境,该函数主要做以下两件事情:

1. 开启多进程,每个 gpu 设备对应单独的一个进程;
2. 每个进程将模型迁移到自己对应的 ``gpu`` 设备上;然后使用 ``DistributedDataParallel`` 包裹模型;
"""
if len(self.optimizers) != 1:
raise ValueError("Multi optimizers is not supported for DeepSpeedDriver right now.")
if self._has_setup:
return
self.setup_config()
self._has_setup = True
# If the user needs multi-node mode, execution must go through this branch;
if self.is_pull_by_torch_run:
# dist.get_world_size() can only be called after dist.init_process_group has been called;
self.world_size = int(os.environ.get("WORLD_SIZE"))
self.global_rank = int(os.environ.get("RANK"))
logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")

if not dist.is_initialized():
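# deepspeed.init_distributed is a thin wrapper around torch.distributed.init_process_group for the given backend (nccl here), picking up the rendezvous settings from the environment;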
deepspeed.init_distributed("nccl", distributed_port=self.master_port)

os.environ["fastnlp_torch_launch_not_ddp"] = "yes"

# When we reach this branch:
# dist.is_initialized must be False;
# it must be single-node;
# self.parallel_device must be a List[torch.device];
else:
if not dist.is_initialized():
# the main issue here is distinguishing rank 0 from the other ranks;
self.world_size = len(self.parallel_device)
self.open_subprocess()
self.global_rank = self.local_rank  # the rank must be obtained from environment variables;
deepspeed.init_distributed("nccl", distributed_port=self.master_port)
# the user has already initialized another trainer before this one, using TorchDDPDriver;
else:
# If `dist.is_initialized() == True`, a TorchDDPDriver has already been initialized and set up once before; we must
# ensure that the driver used now (i.e. the later one) has exactly the same settings as the first TorchDDPDriver;
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
if pre_num_processes != len(self.parallel_device):
raise RuntimeError(
"Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not"
"allowed that your second `TorchDDPDriver` has a new setting of parameters "
"`num_nodes` and `num_processes`.")
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()

torch.cuda.set_device(self.model_device)
self.configure_ddp()

self.barrier()
# initialize self._pids so that every process can receive the send from rank 0;
self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
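# if LOCAL_WORLD_SIZE is not set, infer it as the maximum LOCAL_RANK across all processes plus one;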
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
if local_world_size is None:
local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
local_world_size = local_world_size.tolist() + 1

node_rank = self.global_rank // local_world_size
self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
self._pids = self.tensor_to_numeric(self._pids)

def configure_ddp(self):
# set up deepspeed
if not isinstance(self.model, deepspeed.DeepSpeedEngine):
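# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler); the engine replaces self.model below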
self.model, ds_optimizer, _, _ = deepspeed.initialize(
model=_DDPWrappingModel(self.model),
optimizer=self.optimizers[0],
config=self.config
)
# TODO: is this necessary?
self._optimizers = [ds_optimizer]

if self.config.get("activation_checkpointing"):
checkpoint_config = self.config["activation_checkpointing"]
deepspeed.checkpointing.configure(
mpu_=None,
partition_activations=checkpoint_config.get("partition_activations"),
contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
profile=checkpoint_config.get("profile"),
)

self._has_ddpwrapped = True

def setup_config(self):

if self.strategy == "deepspeed":
self.config = _create_default_config(stage=2)
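# stage 2 corresponds to ZeRO stage 2: optimizer states and gradients are partitioned across processes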
@@ -53,113 +199,45 @@ class DeepSpeedDriver(TorchDriver):
offload_params_device="nvme",
offload_optimizer_device="nvme",
)
for i, optimizer in enumerate(self.optimizers):
# TODO: multiple optimizers
engine, optimizer_ds, _, _ = deepspeed.initialize(
model=self.model,
optimizer=optimizer,
config=self.config
)
self._optimizers[i] = optimizer_ds
self.model = engine

self._set_deepspeed_activation_checkpointing()

def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def set_dist_repro_dataloader(self, dataloader,
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
reproducible: bool = False):
return dataloader
# If dist is a ReproducibleBatchSampler or ReproducibleSampler, this call comes from driver.load_checkpoint during checkpoint resumption;
if isinstance(dist, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)

# If dist is a str or None, this is being called during trainer initialization;
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)

if reproducible:
if type(args.batch_sampler) is TorchBatchSampler:
if type(args.sampler) is TorchRandomSampler:
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# if it was already random and not customized, just replace it directly.
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif type(args.sampler) is TorchSequentialSampler:
# needs to be replaced with a sampler that does not shuffle.
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader
raise ValueError(f"Unknown deepspeed strategy {self.strategy}.")

self.config.setdefault("train_micro_batch_size_per_gpu", 1)
self.config.setdefault("steps_per_print", 2147483647)

# TODO: gradient clipping settings; the trainer is needed here
# fetch from kwargs
# precision settings
# _format_precision_config
if self.fp16:
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
logger.debug("Enabling DeepSpeed FP16.")
# TODO: allow users to customize this section, as pytorch-lightning does
self.config["fp16"] = {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": True,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1,
}
elif "amp" not in self.config:
logger.debug("Enabling DeepSpeed APEX Implementation.")
self.config["amp"] = {"enabled": True, "opt_level": "O1"}

def zero_grad(self):
# DeepSpeedEngine.step already includes the zero_grad functionality
pass

def backward(self, loss):
self.model.backward(loss)

def step(self):
self.model.step()

def unwrap_model(self):
r"""
:return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹;
"""
if isinstance(self.model, deepspeed.DeepSpeedEngine):
return self.model.module
if isinstance(self.model, torch.nn.DataParallel) or \
isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
return self.model.module
else:
return self.model

@property
def data_device(self):
r"""
注意单卡模式下使用 ``driver.data_device`` 等价于使用 ``driver.model_device``;
"""
return self.model_device

def is_distributed(self):
r"""
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``TorchSingleDriver`` 来说直接返回 ``False``;
"""
return False

def _set_deepspeed_activation_checkpointing(self):
if self.config.get("activation_checkpointing"):
checkpoint_config = self.config["activation_checkpointing"]
deepspeed.checkpointing.configure(
mpu_=None,
partition_activations=checkpoint_config.get("partition_activations"),
contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
profile=checkpoint_config.get("profile"),
)
return self.model.module.model
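
For reference, the engine API that configure_ddp, backward and step delegate to follows the pattern below. This is a minimal sketch with a hypothetical toy model and optimizer rather than fastNLP code; it assumes DeepSpeed is installed and that the script is launched with the deepspeed (or torchrun) launcher on a GPU machine, so the distributed environment variables are already set.

    import torch
    import deepspeed

    # Hypothetical toy model/optimizer; the config keys mirror the defaults set in setup_config().
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 2147483647,
    }

    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
    engine, ds_optimizer, _, _ = deepspeed.initialize(
        model=model, optimizer=optimizer, config=ds_config)

    batch = torch.randn(4, 10).to(engine.device)
    loss = engine(batch).sum()   # forward pass through the DeepSpeedEngine
    engine.backward(loss)        # replaces loss.backward(); handles any loss scaling
    engine.step()                # optimizer step; gradients are zeroed here as well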

+1 -1  fastNLP/envs/imports.py

@@ -22,6 +22,6 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import
_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'deepspeed' in need_import
_NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_import

_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")
