Browse Source

添加选择deepspeed driver的逻辑

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
d26d0ad17f
2 changed files with 12 additions and 4 deletions
  1. +1
    -1
      fastNLP/core/drivers/choose_driver.py
  2. +11
    -3
      fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

+ 1
- 1
fastNLP/core/drivers/choose_driver.py View File

@@ -17,7 +17,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
if isinstance(driver, Driver):
return driver

if driver in {"torch", "fairscale"}:
if driver in {"torch", "fairscale", "deepspeed"}:
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
return initialize_torch_driver(driver, device, model, **kwargs)
elif driver in {"jittor"}:


+ 11
- 3
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py View File

@@ -8,6 +8,7 @@ from .torch_driver import TorchDriver
from .single_device import TorchSingleDriver
from .ddp import TorchDDPDriver
from .fairscale import FairScaleDriver
from .deepspeed import DeepSpeedDriver
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
from pkg_resources import parse_version
@@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
r"""
用来根据参数 ``driver` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去;

:param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``;
:param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``;
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致;
:param model: 训练或者评测的具体的模型;

@@ -41,7 +42,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
is_pull_by_torch_run=True, **kwargs)

if driver not in {"torch", "fairscale"}:
if driver not in {"torch", "fairscale", "deepspeed"}:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")

_could_use_device_num = torch.cuda.device_count()
@@ -83,4 +84,11 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
return FairScaleDriver(model, [device], **kwargs)
else:
return FairScaleDriver(model, device, **kwargs)
return FairScaleDriver(model, device, **kwargs)
elif driver == "deepspeed":
if not isinstance(device, List):
if device.type == 'cpu':
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
return DeepSpeedDriver(model, [device], **kwargs)
return DeepSpeedDriver(model, device, **kwargs)

Loading…
Cancel
Save