|
|
@@ -8,6 +8,7 @@ from .torch_driver import TorchDriver |
|
|
|
from .single_device import TorchSingleDriver |
|
|
|
from .ddp import TorchDDPDriver |
|
|
|
from .fairscale import FairScaleDriver |
|
|
|
from .deepspeed import DeepSpeedDriver |
|
|
|
from fastNLP.core.log import logger |
|
|
|
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH |
|
|
|
from pkg_resources import parse_version |
|
|
@@ -20,7 +21,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi |
|
|
|
r""" |
|
|
|
用来根据参数 ``driver`` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去; |
|
|
|
|
|
|
|
:param driver: 该参数的值应为以下之一:``["torch", "fairscale"]``; |
|
|
|
:param driver: 该参数的值应为以下之一:``["torch", "fairscale", "deepspeed"]``; |
|
|
|
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; |
|
|
|
:param model: 训练或者评测的具体的模型; |
|
|
|
|
|
|
@@ -41,7 +42,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi |
|
|
|
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), |
|
|
|
is_pull_by_torch_run=True, **kwargs) |
|
|
|
|
|
|
|
if driver not in {"torch", "fairscale"}: |
|
|
|
if driver not in {"torch", "fairscale", "deepspeed"}: |
|
|
|
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") |
|
|
|
|
|
|
|
_could_use_device_num = torch.cuda.device_count() |
|
|
@@ -83,4 +84,11 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi |
|
|
|
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") |
|
|
|
return FairScaleDriver(model, [device], **kwargs) |
|
|
|
else: |
|
|
|
return FairScaleDriver(model, device, **kwargs) |
|
|
|
return FairScaleDriver(model, device, **kwargs) |
|
|
|
elif driver == "deepspeed": |
|
|
|
if not isinstance(device, List): |
|
|
|
if device.type == 'cpu': |
|
|
|
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") |
|
|
|
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") |
|
|
|
return DeepSpeedDriver(model, [device], **kwargs) |
|
|
|
return DeepSpeedDriver(model, device, **kwargs) |