diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py index 4be1e502..56d30e6f 100644 --- a/fastNLP/core/drivers/choose_driver.py +++ b/fastNLP/core/drivers/choose_driver.py @@ -17,7 +17,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, if isinstance(driver, Driver): return driver - if driver in {"torch", "fairscale"}: + if driver in {"torch", "fairscale", "deepspeed"}: from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver return initialize_torch_driver(driver, device, model, **kwargs) elif driver in {"jittor"}: diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 0deac4dc..5d4d2ab5 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -8,6 +8,7 @@ from .torch_driver import TorchDriver from .single_device import TorchSingleDriver from .ddp import TorchDDPDriver from .fairscale import FairScaleDriver +from .deepspeed import DeepSpeedDriver from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_BACKEND_LAUNCH from pkg_resources import parse_version @@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi r""" 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; - :param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``; + :param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``; :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; :param model: 训练或者评测的具体的模型; @@ -41,7 +42,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), is_pull_by_torch_run=True, **kwargs) - if driver not in {"torch", "fairscale"}: + if driver not in {"torch", "fairscale", "deepspeed"}: raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") _could_use_device_num = torch.cuda.device_count() @@ -83,4 +84,11 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") return FairScaleDriver(model, [device], **kwargs) else: - return FairScaleDriver(model, device, **kwargs) \ No newline at end of file + return FairScaleDriver(model, device, **kwargs) + elif driver == "deepspeed": + if not isinstance(device, List): + if device.type == 'cpu': + raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") + logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") + return DeepSpeedDriver(model, [device], **kwargs) + return DeepSpeedDriver(model, device, **kwargs) \ No newline at end of file