Browse Source

添加选择deepspeed driver的逻辑

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
d26d0ad17f
2 changed files with 12 additions and 4 deletions
  1. +1
    -1
      fastNLP/core/drivers/choose_driver.py
  2. +11
    -3
      fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

+ 1
- 1
fastNLP/core/drivers/choose_driver.py View File

@@ -17,7 +17,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
if isinstance(driver, Driver): if isinstance(driver, Driver):
return driver return driver


if driver in {"torch", "fairscale"}:
if driver in {"torch", "fairscale", "deepspeed"}:
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
return initialize_torch_driver(driver, device, model, **kwargs) return initialize_torch_driver(driver, device, model, **kwargs)
elif driver in {"jittor"}: elif driver in {"jittor"}:


+ 11
- 3
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py View File

@@ -8,6 +8,7 @@ from .torch_driver import TorchDriver
from .single_device import TorchSingleDriver from .single_device import TorchSingleDriver
from .ddp import TorchDDPDriver from .ddp import TorchDDPDriver
from .fairscale import FairScaleDriver from .fairscale import FairScaleDriver
from .deepspeed import DeepSpeedDriver
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
from pkg_resources import parse_version from pkg_resources import parse_version
@@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
r""" r"""
用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; 用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去;


:param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``;
:param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``;
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; :param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致;
:param model: 训练或者评测的具体的模型; :param model: 训练或者评测的具体的模型;


@@ -41,7 +42,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
is_pull_by_torch_run=True, **kwargs) is_pull_by_torch_run=True, **kwargs)


if driver not in {"torch", "fairscale"}:
if driver not in {"torch", "fairscale", "deepspeed"}:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")


_could_use_device_num = torch.cuda.device_count() _could_use_device_num = torch.cuda.device_count()
@@ -83,4 +84,11 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
return FairScaleDriver(model, [device], **kwargs) return FairScaleDriver(model, [device], **kwargs)
else: else:
return FairScaleDriver(model, device, **kwargs)
return FairScaleDriver(model, device, **kwargs)
elif driver == "deepspeed":
if not isinstance(device, List):
if device.type == 'cpu':
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
return DeepSpeedDriver(model, [device], **kwargs)
return DeepSpeedDriver(model, device, **kwargs)

Loading…
Cancel
Save