|
|
@@ -1,17 +1,16 @@ |
|
|
|
import os |
|
|
|
|
|
|
|
from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List |
|
|
|
from .torch_driver import TorchDriver |
|
|
|
from .utils import _create_default_config |
|
|
|
from fastNLP.core.utils import auto_param_call |
|
|
|
from fastNLP.core.utils.utils import _get_fun_msg |
|
|
|
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ |
|
|
|
ReproduceBatchSampler |
|
|
|
from .ddp import TorchDDPDriver |
|
|
|
from .utils import _create_default_config, _DDPWrappingModel |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import pytorch_lightning |
|
|
|
import torch |
|
|
|
from torch.nn import DataParallel |
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
if _NEED_IMPORT_DEEPSPEED: |
|
|
|
import deepspeed |
|
|
@@ -20,13 +19,160 @@ __all__ = [ |
|
|
|
"DeepSpeedDriver", |
|
|
|
] |
|
|
|
|
|
|
|
class DeepSpeedDriver(TorchDriver): |
|
|
|
def __init__(self, model, fp16, strategy, **kwargs): |
|
|
|
super(DeepSpeedDriver, self).__init__(model, fp16) |
|
|
|
|
|
|
|
class DeepSpeedDriver(TorchDDPDriver): |
|
|
|
# TODO fp16 load_config |
|
|
|
def __init__(
        self,
        model,
        parallel_device: Union[List["torch.device"], "torch.device"],
        is_pull_by_torch_run=False,
        fp16: bool = False,
        strategy="deepspeed",
        **kwargs
):
    r"""
    Record the distributed and deepspeed settings of the driver.

    :param model: the model to drive; forwarded to ``TorchDriver.__init__``.
    :param parallel_device: the device(s) to train on. When ``is_pull_by_torch_run`` is ``True`` it is
        used as ``model_device`` directly, otherwise it is indexed with ``self.local_rank``.
    :param is_pull_by_torch_run: whether the processes were launched from outside
        (e.g. by ``python -m torch.distributed.launch``).
    :param fp16: stored on ``self.fp16``; ``TorchDriver.__init__`` itself is always called with ``fp16=False``.
    :param strategy: name of the deepspeed strategy, stored on ``self.strategy``.
    :param kwargs: forwarded to ``TorchDriver.__init__``. The keys ``data_device``,
        ``output_from_new_proc`` and ``deepspeed_kwargs`` are additionally read here.
    :raises ValueError: if ``parallel_device`` is ``None`` while ``is_pull_by_torch_run`` is ``False``,
        or if ``data_device`` is negative, not smaller than ``torch.cuda.device_count()``, or of an
        unsupported type.
    """
    assert _NEED_IMPORT_DEEPSPEED, "deepspeed is not imported."
    assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user."
    # TorchDriver.__init__ is called explicitly (not via super()) and with fp16 switched off;
    # the requested flag is kept on the instance instead.
    # NOTE(review): presumably so that deepspeed, not torch AMP, handles mixed precision -- confirm.
    TorchDriver.__init__(self, model=model, fp16=False, **kwargs)
    self.fp16 = fp16

    # If the user initialised DDP outside, the processes must have been launched through
    # `python -m torch.distributed.launch`.
    self.is_pull_by_torch_run = is_pull_by_torch_run
    self.parallel_device = parallel_device
    if not is_pull_by_torch_run and parallel_device is None:
        raise ValueError(
            "Parameter `parallel_device` can not be None when using `DeepSpeedDriver`. This error is caused "
            "when your value of parameter `device` is `None` in your `Trainer` instance.")

    # Note: in initialize_torch_driver, when is_pull_by_torch_run is set, parallel_device is
    # directly set to the gpu of the current process.
    if is_pull_by_torch_run:
        self.model_device = parallel_device
    else:
        # model_device is always a single torch.device, never a list.
        self.model_device = parallel_device[self.local_rank]

    # Initialising the process group outside is not allowed for now.
    self.outside_ddp = False
    self._data_device = kwargs.get("data_device", None)
    if isinstance(self._data_device, int):
        if self._data_device < 0:
            raise ValueError("Parameter `data_device` can not be smaller than 0.")
        _could_use_device_num = torch.cuda.device_count()
        if self._data_device >= _could_use_device_num:
            raise ValueError("The gpu device that parameter `data_device` specifies is not existed.")
        self._data_device = torch.device(f"cuda:{self._data_device}")
    elif isinstance(self._data_device, str):
        self._data_device = torch.device(self._data_device)
    elif self._data_device is not None and not isinstance(self._data_device, torch.device):
        raise ValueError("Parameter `data_device` is wrong type, please check our documentation for the right use.")

    self._master_port = None
    # world_size is the total number of gpus; it is filled in later.
    self.world_size = None  # int(os.environ.get("WORLD_SIZE"))  len(self.parallel_device)
    self.global_rank = 0

    self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
    assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
    if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
        # Any other value is treated as a directory: create it and store its absolute path.
        os.makedirs(name=self.output_from_new_proc, exist_ok=True)
        self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

    # Needed because the evaluator also calls setup, which must not run a second time.
    self._has_setup = False
    # Records whether the passed-in model has already been wrapped.
    self._has_ddpwrapped = False
    self.strategy = strategy

    self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
|
|
|
|
|
|
|
def setup(self):
    r"""
    Prepare the distributed environment. In order, this method:

    1. builds the deepspeed config through ``self.setup_config()``;
    2. determines ``world_size`` / ``global_rank`` and initialises the process group with
       ``deepspeed.init_distributed`` (opening sub-processes first when the processes were not
       launched from outside);
    3. selects the cuda device of this process and calls ``self.configure_ddp()``;
    4. gathers the pids of all processes and keeps those of the current node in ``self._pids``.

    Calling it again after a successful setup is a no-op.

    :raises ValueError: if the driver does not hold exactly one optimizer.
    :raises RuntimeError: if a process group already exists and was created with a different
        number of processes than ``len(self.parallel_device)``.
    """
    if len(self.optimizers) != 1:
        raise ValueError("Multi optimizers is not supported for DeepSpeedDriver right now.")
    if self._has_setup:
        return
    self.setup_config()
    self._has_setup = True
    # Multi-node usage always goes through this branch.
    if self.is_pull_by_torch_run:
        # dist.get_world_size() can only be called after dist.init_process_group, so the
        # values are read from the environment instead.
        self.world_size = int(os.environ.get("WORLD_SIZE"))
        self.global_rank = int(os.environ.get("RANK"))
        logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")

        if not dist.is_initialized():
            deepspeed.init_distributed("nccl", distributed_port=self.master_port)

        os.environ["fastnlp_torch_launch_not_ddp"] = "yes"

    # When we get here:
    #  dist.is_initialized is expected to be False;
    #  we are on a single node;
    #  self.parallel_device is a List[torch.device].
    else:
        if not dist.is_initialized():
            # The main difficulty here is distinguishing rank0 from the other ranks.
            self.world_size = len(self.parallel_device)
            self.open_subprocess()
            self.global_rank = self.local_rank  # the rank is always read from the environment
            deepspeed.init_distributed("nccl", distributed_port=self.master_port)
        # The user created another trainer before this one, and it used TorchDDPDriver.
        else:
            # If `dist.is_initialized() == True`, a TorchDDPDriver has already been initialised and
            # set up once, so the current driver must use exactly the same settings as the first.
            pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
            if pre_num_processes != len(self.parallel_device):
                raise RuntimeError(
                    "Notice you are using `DeepSpeedDriver` after one instantiated `TorchDDPDriver`, it is not "
                    "allowed that your second driver has a new setting of parameters "
                    "`num_nodes` and `num_processes`.")
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()

    torch.cuda.set_device(self.model_device)
    self.configure_ddp()

    self.barrier()
    # Initialise self._pids so that every process can receive the send from rank0.
    self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
    dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
    if local_world_size is None:
        # Fall back to (largest LOCAL_RANK over all processes) + 1.
        local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
        dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
        local_world_size = local_world_size.tolist() + 1

    # Keep only the pids that belong to the node of this process.
    node_rank = self.global_rank // local_world_size
    self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
    self._pids = self.tensor_to_numeric(self._pids)
|
|
|
|
|
|
|
def configure_ddp(self):
    r"""
    Wrap ``self.model`` into a deepspeed engine unless it already is one.

    The model is first wrapped in ``_DDPWrappingModel`` and then passed, together with the first
    optimizer and ``self.config``, to ``deepspeed.initialize``. The returned engine replaces
    ``self.model`` and the returned optimizer replaces ``self._optimizers``. When the config has
    an ``activation_checkpointing`` section, ``deepspeed.checkpointing.configure`` is called with it.
    """
    # NOTE(review): the original indentation of this method was not recoverable; everything below
    # is assumed to run only when the model is not yet a DeepSpeedEngine -- confirm.
    if isinstance(self.model, deepspeed.DeepSpeedEngine):
        return

    # Set up deepspeed.
    wrapped_model = _DDPWrappingModel(self.model)
    engine, ds_optimizer, _, _ = deepspeed.initialize(
        model=wrapped_model,
        optimizer=self.optimizers[0],
        config=self.config
    )
    self.model = engine
    # TODO: is this necessary?
    self._optimizers = [ds_optimizer]

    checkpoint_config = self.config.get("activation_checkpointing")
    if checkpoint_config:
        deepspeed.checkpointing.configure(
            mpu_=None,
            partition_activations=checkpoint_config.get("partition_activations"),
            contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
            checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
            profile=checkpoint_config.get("profile"),
        )

    self._has_ddpwrapped = True
|
|
|
|
|
|
|
def setup_config(self): |
|
|
|
|
|
|
|
if self.strategy == "deepspeed": |
|
|
|
self.config = _create_default_config(stage=2) |
|
|
@@ -53,113 +199,45 @@ class DeepSpeedDriver(TorchDriver): |
|
|
|
offload_params_device="nvme", |
|
|
|
offload_optimizer_device="nvme", |
|
|
|
) |
|
|
|
for i, optimizer in enumerate(self.optimizers): |
|
|
|
# TODO 多个 optimizer |
|
|
|
engine, optimizer_ds, _, _ = deepspeed.initialize( |
|
|
|
model=self.model, |
|
|
|
optimizer=optimizer, |
|
|
|
config=self.config |
|
|
|
) |
|
|
|
self._optimizers[i] = optimizer_ds |
|
|
|
self.model = engine |
|
|
|
|
|
|
|
self._set_deepspeed_activation_checkpointing() |
|
|
|
|
|
|
|
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
    r"""
    Run ``fn`` on ``batch``.

    :param batch: the input of the current step.
    :param fn: the callable to run.
    :param signature_fn: passed to ``auto_param_call`` as ``signature_fn``.
    :return: the result of ``auto_param_call(fn, batch, signature_fn=signature_fn)`` when ``batch``
        is a ``Dict`` and ``self.wo_auto_param_call`` is falsy; otherwise the result of ``fn(batch)``.
    """
    match_params_by_name = isinstance(batch, Dict) and not self.wo_auto_param_call
    if not match_params_by_name:
        return fn(batch)
    return auto_param_call(fn, batch, signature_fn=signature_fn)
|
|
|
|
|
|
|
def get_model_call_fn(self, fn: str) -> Tuple:
    r"""
    Look up the callable that should be used for the step named ``fn``.

    :param fn: name of the attribute to look up on ``self.model``.
    :return: ``(attribute, None)`` when ``self.model`` has a callable attribute called ``fn``;
        ``(self.model, self.model.forward)`` when it has no such attribute and ``fn`` is
        ``"train_step"`` or ``"evaluate_step"``.
    :raises RuntimeError: if the attribute exists but is not callable, or if it does not exist and
        ``fn`` is not one of the two names above.
    """
    if hasattr(self.model, fn):
        # Keep the name in ``fn`` so that the error message reports the attribute name rather
        # than the attribute's value.
        model_fn = getattr(self.model, fn)
        if not callable(model_fn):
            raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
        logger.debug(f'Use {_get_fun_msg(model_fn, with_fp=False)}...')
        return model_fn, None
    elif fn in {"train_step", "evaluate_step"}:
        logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
        return self.model, self.model.forward
    else:
        raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
|
|
|
|
|
|
|
def set_dist_repro_dataloader(self, dataloader,
                              dist: Union[str, "ReproducibleBatchSampler", "ReproducibleSampler"] = None,
                              reproducible: bool = False):
    r"""
    Return ``dataloader`` unchanged.

    :param dataloader: the dataloader to prepare.
    :param dist: accepted for interface compatibility; currently ignored.
    :param reproducible: accepted for interface compatibility; currently ignored.
    :return: the very same ``dataloader`` object.
    """
    # NOTE(review): the sampler-replacement logic that used to follow this return could never run,
    # and it referred to helpers (``replace_sampler``, ``replace_batch_sampler``, ``RandomSampler``,
    # ``TorchBatchSampler``, ...) that are not among the imports visible at the top of this file,
    # so it has been dropped. Re-introduce it together with those imports if reproducible
    # samplers are to be supported here.
    return dataloader
|
|
|
raise ValueError(f"Unknown deepspeed strategy {self.strategy}.") |
|
|
|
|
|
|
|
self.config.setdefault("train_micro_batch_size_per_gpu", 1) |
|
|
|
self.config.setdefault("steps_per_print", 2147483647) |
|
|
|
|
|
|
|
# TODO 梯度裁剪的设置,这里需要用到trainer |
|
|
|
# 从kwargs 获取 |
|
|
|
# 精度设置 |
|
|
|
# _format_precision_config |
|
|
|
if self.fp16: |
|
|
|
if "fp16" not in self.config: |
|
|
|
# FP16 is a DeepSpeed standalone AMP implementation |
|
|
|
logger.debug("Enabling DeepSpeed FP16.") |
|
|
|
# TODO 这部分是否可以像 pytorch-lightning 那样给用户定制 |
|
|
|
self.config["fp16"] = { |
|
|
|
"enabled": True, |
|
|
|
"loss_scale": 0, |
|
|
|
"initial_scale_power": True, |
|
|
|
"loss_scale_window": 1000, |
|
|
|
"hysteresis": 2, |
|
|
|
"min_loss_scale": 1, |
|
|
|
} |
|
|
|
elif "amp" not in self.config: |
|
|
|
logger.debug("Enabling DeepSpeed APEX Implementation.") |
|
|
|
self.config["amp"] = {"enabled": True, "opt_level": "O1"} |
|
|
|
|
|
|
|
def zero_grad(self):
    r"""
    Do nothing.
    """
    # DeepSpeedEngine.step already includes the zero_grad functionality, so no work is left here.
    return None
|
|
|
|
|
|
|
def backward(self, loss):
    r"""
    Delegate the backward pass to ``self.model.backward``.

    :param loss: the loss value handed to ``self.model.backward``.
    """
    engine = self.model
    engine.backward(loss)
|
|
|
|
|
|
|
def step(self):
    r"""
    Delegate the optimisation step to ``self.model.step``.
    """
    engine = self.model
    engine.step()
|
|
|
|
|
|
|
def unwrap_model(self):
    r"""
    :return: ``self.model.module`` when ``self.model`` is a ``deepspeed.DeepSpeedEngine``,
        a ``torch.nn.DataParallel`` or a ``torch.nn.parallel.DistributedDataParallel``;
        otherwise ``self.model`` itself.
    """
    # The debug ``print`` of the wrapped module that used to live here was removed: it dumped the
    # whole module to stdout on every call.
    if isinstance(self.model, deepspeed.DeepSpeedEngine):
        return self.model.module
    if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        return self.model.module
    return self.model
|
|
|
|
|
|
|
@property
def data_device(self):
    r"""
    The device returned here is ``self.model_device``.

    Note: in single-card mode ``driver.data_device`` is equivalent to ``driver.model_device``.
    """
    device = self.model_device
    return device
|
|
|
|
|
|
|
def is_distributed(self):
    r"""
    :return: whether the driver in use is a distributed driver; this implementation always
        answers ``False`` (the original note states that ``TorchSingleDriver`` returns ``False``
        directly).
    """
    distributed = False
    return distributed
|
|
|
|
|
|
|
def _set_deepspeed_activation_checkpointing(self):
    r"""
    Configure deepspeed activation checkpointing from ``self.config``.

    Nothing happens unless ``self.config`` holds a truthy ``"activation_checkpointing"`` entry;
    in that case its ``partition_activations``, ``contiguous_memory_optimization``,
    ``cpu_checkpointing`` and ``profile`` values are handed to
    ``deepspeed.checkpointing.configure``.
    """
    checkpoint_config = self.config.get("activation_checkpointing")
    if not checkpoint_config:
        return
    deepspeed.checkpointing.configure(
        mpu_=None,
        partition_activations=checkpoint_config.get("partition_activations"),
        contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
        checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
        profile=checkpoint_config.get("profile"),
    )
|
|
|
return self.model.module.model |