From d26d0ad17f1fa896d31d446bed54b58c028b450c Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 17 Jun 2022 22:16:55 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=80=89=E6=8B=A9deepspeed?= =?UTF-8?q?=20driver=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/choose_driver.py | 2 +- .../core/drivers/torch_driver/initialize_torch_driver.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py index 4be1e502..56d30e6f 100644 --- a/fastNLP/core/drivers/choose_driver.py +++ b/fastNLP/core/drivers/choose_driver.py @@ -17,7 +17,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, if isinstance(driver, Driver): return driver - if driver in {"torch", "fairscale"}: + if driver in {"torch", "fairscale", "deepspeed"}: from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver return initialize_torch_driver(driver, device, model, **kwargs) elif driver in {"jittor"}: diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 0deac4dc..5d4d2ab5 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -8,6 +8,7 @@ from .torch_driver import TorchDriver from .single_device import TorchSingleDriver from .ddp import TorchDDPDriver from .fairscale import FairScaleDriver +from .deepspeed import DeepSpeedDriver from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_BACKEND_LAUNCH from pkg_resources import parse_version @@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi r""" 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; - :param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``; + :param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``; :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; :param model: 训练或者评测的具体的模型; @@ -41,7 +42,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), is_pull_by_torch_run=True, **kwargs) - if driver not in {"torch", "fairscale"}: + if driver not in {"torch", "fairscale", "deepspeed"}: raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") _could_use_device_num = torch.cuda.device_count() @@ -83,4 +84,11 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") return FairScaleDriver(model, [device], **kwargs) else: - return FairScaleDriver(model, device, **kwargs) \ No newline at end of file + return FairScaleDriver(model, device, **kwargs) + elif driver == "deepspeed": + if not isinstance(device, List): + if device.type == 'cpu': + raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") + logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") + return DeepSpeedDriver(model, [device], **kwargs) + return DeepSpeedDriver(model, device, **kwargs) \ No newline at end of file