diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py
index 579a50f4..3d519099 100644
--- a/fastNLP/core/drivers/torch_driver/deepspeed.py
+++ b/fastNLP/core/drivers/torch_driver/deepspeed.py
@@ -1,4 +1,6 @@
 import os
+import argparse
+import logging
 from pathlib import Path
 from typing import Union, Dict, List
@@ -46,7 +48,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         self.parallel_device = parallel_device
         if not is_pull_by_torch_run and parallel_device is None:
             raise ValueError(
-                "Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused "
+                "Parameter `parallel_device` can not be None when using `TorchDeepSpeedDriver`. This error is caused "
                 "when your value of parameter `device` is `None` in your `Trainer` instance.")
 
         # Note: the logic in initialize_torch_driver is that if is_pull_by_torch_run, we directly set parallel_device to the GPU of the current process;
@@ -68,8 +70,6 @@ class DeepSpeedDriver(TorchDDPDriver):
             self.outside_ddp = True
             self.config = model.config
-            # The model can only be wrapped with DistributedDataParallel after it has been moved to the target machine, so if the user initialized DDP outside, then in TorchDDPDriver
-            # we simply set model_device to None;
             self.model_device = None
 
         self._data_device = kwargs.get("data_device", None)
@@ -110,6 +110,8 @@ class DeepSpeedDriver(TorchDDPDriver):
 
         self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
         self.strategy = self._ds_kwargs.get("strategy", "deepspeed")
+        deepspeed_logging_level = self._ds_kwargs.get("logging_level", logging.ERROR)
+        deepspeed.utils.logging.logger.setLevel(deepspeed_logging_level)
 
     @staticmethod
     def _check_optimizer_legality(optimizers):
@@ -126,7 +128,7 @@ class DeepSpeedDriver(TorchDDPDriver):
        2. Each process moves the model onto its own ``gpu`` device, then wraps the model with ``DistributedDataParallel``;
        """
         if len(self.optimizers) != 1:
-            raise ValueError("Multi optimizers is not supported for DeepSpeedDriver right now.")
+            raise ValueError("Multi optimizers is not supported for `DeepSpeedDriver` right now.")
         if self._has_setup:
             return
         self._has_setup = True
@@ -173,6 +175,9 @@ class DeepSpeedDriver(TorchDDPDriver):
 
         if not self.outside_ddp:
             torch.cuda.set_device(self.model_device)
+            # TODO: a very large model will probably overflow GPU memory here, but without this call the memory gets allocated on the device matching the rank
+            # Lightning broadcasts log_dir via broadcast_list beforehand, so it does not hit this issue
+            self.model.to(self.model_device)
             self.configure_ddp()
             self.barrier()
@@ -196,10 +201,12 @@ class DeepSpeedDriver(TorchDDPDriver):
             model=_DeepSpeedWrappingModel(self.model, self.fp16)
             model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
+                args=argparse.Namespace(device_rank=self.model_device.index),
                 model=model,
                 optimizer=self.optimizers[0],
                 model_parameters=model_parameters,
                 config=self.config,
+                dist_init_required=False
             )
             self._optimizers = [ds_optimizer]
diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
index b0a16112..f242b813 100644
--- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -38,7 +38,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         if driver == 'fairscale':
             return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                                    is_pull_by_torch_run=True, **kwargs)
-        elif kwargs.get("deepspeed_kwargs") is not None:
+        elif driver == 'deepspeed':
             return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                                    is_pull_by_torch_run=True, **kwargs)
         else:
@@ -76,14 +76,6 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
 
     if driver == "torch":  # single, ddp, launched directly
-        if kwargs.get("deepspeed_kwargs") is not None:
-            # deepspeed was chosen
-            if not isinstance(device, List):
-                if device.type == 'cpu':
-                    raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
-                logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
-                return DeepSpeedDriver(model, [device], **kwargs)
-            return DeepSpeedDriver(model, device, **kwargs)
         if not isinstance(device, List):
             return TorchSingleDriver(model, device, **kwargs)
         else:
@@ -95,4 +87,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
             logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
             return FairScaleDriver(model, [device], **kwargs)
         else:
-            return FairScaleDriver(model, device, **kwargs)
\ No newline at end of file
+            return FairScaleDriver(model, device, **kwargs)
+    elif driver == "deepspeed":
+        if not isinstance(device, List):
+            if device.type == 'cpu':
+                raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
+            logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
+            return DeepSpeedDriver(model, [device], **kwargs)
+        else:
+            return DeepSpeedDriver(model, device, **kwargs)
\ No newline at end of file
diff --git a/tests/core/controllers/_test_trainer_deepspeed.py b/tests/core/controllers/_test_trainer_deepspeed.py
index 2dc6326c..0c51e47c 100644
--- a/tests/core/controllers/_test_trainer_deepspeed.py
+++ b/tests/core/controllers/_test_trainer_deepspeed.py
@@ -60,7 +60,7 @@ def test_trainer_deepspeed(
     config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
     trainer = Trainer(
         model=model,
-        driver="torch",
+        driver="deepspeed",
         device=device,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
@@ -79,7 +79,7 @@ def test_trainer_deepspeed(
     trainer.run()
 
 if __name__ == "__main__":
-    device = [0,1]
+    device = [4, 5]
     # device = [0,1,3]
     callbacks = [
         # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
diff --git a/tests/core/controllers/_test_trainer_deepspeed_outside.py b/tests/core/controllers/_test_trainer_deepspeed_outside.py
index a8dbd823..6821787e 100644
--- a/tests/core/controllers/_test_trainer_deepspeed_outside.py
+++ b/tests/core/controllers/_test_trainer_deepspeed_outside.py
@@ -69,7 +69,7 @@ def test_trainer_deepspeed(
     )
     trainer = Trainer(
         model=model,
-        driver="torch",
+        driver="deepspeed",
         device=device,
         data_device=torch.device(f"cuda:{local_rank}"),
         optimizers=optimizers,
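
For context, not part of the patch: after this change a user selects the DeepSpeed driver by name instead of routing through driver="torch" with a non-empty deepspeed_kwargs. The sketch below illustrates that call path under a few assumptions: ToyModel, its train_step returning a {"loss": ...} dict, and the dict-batch collate function are illustrative placeholders following fastNLP's usual conventions (they do not appear in this patch); two visible GPUs are assumed; and only the deepspeed_kwargs keys visible in this patch ("strategy", "logging_level") are used.

import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from fastNLP import Trainer


class ToyModel(nn.Module):
    """Placeholder model following fastNLP's convention: train_step returns a dict with a 'loss' key."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def train_step(self, x, y):
        return {"loss": self.loss_fn(self.fc(x), y)}


def collate(samples):
    # Dict batches let fastNLP map the keys onto train_step's parameters.
    xs, ys = zip(*samples)
    return {"x": torch.stack(xs), "y": torch.stack(ys)}


data = [(torch.randn(10), torch.tensor(i % 2)) for i in range(64)]
train_dataloader = DataLoader(data, batch_size=8, collate_fn=collate)

model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

trainer = Trainer(
    model=model,
    driver="deepspeed",                   # before this patch: driver="torch" plus a non-empty deepspeed_kwargs
    device=[0, 1],                        # a single gpu only triggers warning_once; device "cpu" raises ValueError
    optimizers=optimizer,                 # exactly one optimizer; multiple optimizers raise ValueError in setup()
    train_dataloader=train_dataloader,
    n_epochs=1,
    deepspeed_kwargs={
        "strategy": "deepspeed",          # default strategy used by DeepSpeedDriver
        "logging_level": logging.ERROR,   # new in this patch: quiets deepspeed's own logger
    },
)
trainer.run()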