|
|
@@ -3,6 +3,7 @@ __all__ = [ |
|
|
|
'TorchDriver', |
|
|
|
"TorchSingleDriver", |
|
|
|
"TorchDDPDriver", |
|
|
|
"DeepSpeedDriver", |
|
|
|
"PaddleDriver", |
|
|
|
"PaddleSingleDriver", |
|
|
|
"PaddleFleetDriver", |
|
|
@@ -14,7 +15,7 @@ __all__ = [ |
|
|
|
'optimizer_state_to_device' |
|
|
|
] |
|
|
|
|
|
|
|
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device |
|
|
|
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, torch_seed_everything, optimizer_state_to_device |
|
|
|
from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver |
|
|
|
from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything |
|
|
|
from .driver import Driver |
|
|
|