diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 6cf73d3b..052bed5b 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -59,6 +59,7 @@ __all__ = [ # drivers "TorchSingleDriver", "TorchDDPDriver", + "DeepSpeedDriver", "PaddleSingleDriver", "PaddleFleetDriver", "JittorSingleDriver", diff --git a/fastNLP/core/drivers/__init__.py b/fastNLP/core/drivers/__init__.py index f9be3180..127e723a 100644 --- a/fastNLP/core/drivers/__init__.py +++ b/fastNLP/core/drivers/__init__.py @@ -3,6 +3,7 @@ __all__ = [ 'TorchDriver', "TorchSingleDriver", "TorchDDPDriver", + "DeepSpeedDriver", "PaddleDriver", "PaddleSingleDriver", "PaddleFleetDriver", @@ -14,7 +15,7 @@ __all__ = [ 'optimizer_state_to_device' ] -from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device +from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, torch_seed_everything, optimizer_state_to_device from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything from .driver import Driver diff --git a/fastNLP/core/drivers/torch_driver/__init__.py b/fastNLP/core/drivers/torch_driver/__init__.py index 8c24fa53..08026d9e 100644 --- a/fastNLP/core/drivers/torch_driver/__init__.py +++ b/fastNLP/core/drivers/torch_driver/__init__.py @@ -1,6 +1,7 @@ __all__ = [ 'TorchDDPDriver', 'TorchSingleDriver', + 'DeepSpeedDriver', 'TorchDriver', 'torch_seed_everything', 'optimizer_state_to_device' @@ -10,6 +11,7 @@ from .ddp import TorchDDPDriver # todo 实现 fairscale 后再将 fairscale 导入到这里; from .single_device import TorchSingleDriver from .torch_driver import TorchDriver +from .deepspeed import DeepSpeedDriver from .utils import torch_seed_everything, optimizer_state_to_device