diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py index bb4df495..79451b13 100644 --- a/fastNLP/core/drivers/torch_driver/deepspeed.py +++ b/fastNLP/core/drivers/torch_driver/deepspeed.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Union, Dict, List from .torch_driver import TorchDriver from .ddp import TorchDDPDriver -from .utils import _create_default_config, _DDPWrappingModel +from .utils import _create_default_config, _DeepSpeedWrappingModel from fastNLP.core.utils import nullcontext from fastNLP.core.log import logger from fastNLP.envs import( @@ -14,6 +14,7 @@ from fastNLP.envs import( from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED if _NEED_IMPORT_TORCH: + import pytorch_lightning import torch import torch.distributed as dist @@ -35,8 +36,8 @@ class DeepSpeedDriver(TorchDDPDriver): strategy= "deepspeed", **kwargs ): - assert _NEED_IMPORT_DEEPSPEED, "deepspeed is not imported." - assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user." + assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported." + # assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user." TorchDriver.__init__(self, model=model, fp16=False, **kwargs) self.fp16 = fp16 @@ -88,7 +89,7 @@ class DeepSpeedDriver(TorchDDPDriver): # 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数 train_dl = kwargs.get("train_dataloader", None) if train_dl is not None: - self.train_micro_batch_size = self.get_dataloader_args(train_dl) + self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size else: logger.warn("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`" "to 1 for deepspeed configuration.") @@ -166,7 +167,7 @@ class DeepSpeedDriver(TorchDDPDriver): # 设置 deepspeed if not isinstance(self.model, deepspeed.DeepSpeedEngine): - model=_DDPWrappingModel(self.model) + model=_DeepSpeedWrappingModel(self.model, self.fp16) model_parameters = filter(lambda p: p.requires_grad, model.parameters()) self.model, ds_optimizer, _, _ = deepspeed.initialize( model=model, @@ -279,7 +280,7 @@ class DeepSpeedDriver(TorchDDPDriver): :return: """ # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 - if self.zero_stage_3: + if self.stage_3: logger.rank_zero_warning( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory. " @@ -310,7 +311,8 @@ class DeepSpeedDriver(TorchDDPDriver): def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器 # 1. 保存 sampler 的状态 - sampler_state_dict = self.get_sampler_state_dict() + num_consumed_batches = states.pop('num_consumed_batches') + states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) # 2. 保存模型的状态; if not should_save_model: @@ -318,7 +320,7 @@ class DeepSpeedDriver(TorchDDPDriver): "so we will still save the model for you.") self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME), - client_state=sampler_state_dict) + client_state=states) def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: # 1. 加载模型状态; @@ -330,7 +332,9 @@ class DeepSpeedDriver(TorchDDPDriver): raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}") # 2.恢复 sampler 的状态 - states = self.load_sampler_state_dict(states) + sampler_states = states.pop('sampler_states') + states_ret = self.load_sampler_state(dataloader, sampler_states) + states.update(states_ret) return states diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index e2b00aa6..8c44ea37 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -15,7 +15,7 @@ from fastNLP.envs import ( FASTNLP_GLOBAL_SEED, ) from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler -from fastNLP.core.utils import auto_param_call +from fastNLP.core.utils import auto_param_call, apply_to_collection from fastNLP.core.log import logger if _NEED_IMPORT_TORCH: @@ -107,6 +107,29 @@ class _DDPWrappingModel(Module): else: return fn(batch) +class _DeepSpeedWrappingModel(_DDPWrappingModel): + """ + 继承 ``_DDPWrappingModel``,区别在于进行 forward 之前先将 float 数据转换为 float16 + """ + + def __init__(self, model: Module, fp16): + super(_DeepSpeedWrappingModel, self).__init__(model) + self.fp16 = fp16 + + def forward(self, batch, **kwargs): + if self.fp16: + batch = self._move_float_tensors_to_half(batch) + + return super().forward(batch, **kwargs) + + @staticmethod + def batch_to(data): + return data.half() + + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) + return batch + class DummyGradScaler: """ diff --git a/tests/core/drivers/torch_driver/test_deepspeed.py b/tests/core/drivers/torch_driver/test_deepspeed.py index 8f28c332..462648bd 100644 --- a/tests/core/drivers/torch_driver/test_deepspeed.py +++ b/tests/core/drivers/torch_driver/test_deepspeed.py @@ -1,33 +1,30 @@ import os +from pathlib import Path import pytest -from pathlib import Path from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver from fastNLP.core.samplers import ( RandomSampler, - UnrepeatedSampler, BucketedBatchSampler, UnrepeatedRandomSampler, - UnrepeatedSequentialSampler, ) from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset +from tests.helpers.datasets.torch_data import TorchNormalXYDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.envs.distributed import rank_zero_rm from fastNLP import logger - from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist - from torch.utils.data import DataLoader, BatchSampler + from torch.utils.data import DataLoader if _NEED_IMPORT_DEEPSPEED: import deepspeed -def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): +def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all", train_dataloader=None): torch_model = TorchNormalModel_Classification_1(labels, features) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) device = [torch.device(i) for i in device] @@ -35,7 +32,8 @@ def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_ model=torch_model, parallel_device=device, fp16=fp16, - output_from_new_proc=output_from_new_proc + output_from_new_proc=output_from_new_proc, + train_dataloader=train_dataloader ) driver.set_optimizers(torch_opt) driver.setup() @@ -77,33 +75,33 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= ############################################################################ # -# 测试 TorchDDPDriver 的一些函数 +# 测试 TorchDeepSpeedDriver 的一些函数 # ############################################################################ -@pytest.mark.torch -@magic_argv_env_context -def test_multi_drivers(): - """ - 测试使用了多个 TorchDDPDriver 的情况。 - """ - generate_driver(10, 10) - generate_driver(20, 10) +# @pytest.mark.deepspeed +# @magic_argv_env_context +# def test_multi_drivers(): +# """ +# 测试使用了多个 TorchDeepSpeedDriver 的情况。 +# """ +# generate_driver(10, 10) +# generate_driver(20, 10) - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - generate_driver(20, 3, device=[0,1,2]) - assert False - dist.barrier() +# with pytest.raises(RuntimeError): +# # 设备设置不同,应该报错 +# generate_driver(20, 3, device=[0,1,2]) +# assert False +# dist.barrier() - if dist.is_initialized(): - dist.destroy_process_group() +# if dist.is_initialized(): +# dist.destroy_process_group() @magic_argv_env_context def test_multi_optimizers(): torch_model = TorchNormalModel_Classification_1(10, 10) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) - device = [torch.device(i) for i in device] + device = [torch.device(i) for i in [0, 1]] driver = DeepSpeedDriver( model=torch_model, parallel_device=device, @@ -112,57 +110,59 @@ def test_multi_optimizers(): with pytest.raises(ValueError): driver.setup() - if dist.is_initialized(): - dist.destroy_process_group() + # if dist.is_initialized(): + # dist.destroy_process_group() -@pytest.mark.torch +@pytest.mark.deepspeed class TestDeepSpeedDriverFunction: """ 测试 TorchDeepSpeedDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 """ + @classmethod + def setup_class(cls): + cls.driver = generate_driver(10, 10) @magic_argv_env_context def test_simple_functions(self): """ 简单测试多个函数 """ - driver = generate_driver(10, 10) """ 测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 tests/core/utils/test_torch_utils.py中,就不重复测试了 """ - driver.move_data_to_device(torch.rand((32, 64))) + self.driver.move_data_to_device(torch.rand((32, 64))) dist.barrier() """ 测试 is_distributed 函数 """ - assert driver.is_distributed() == True + assert self.driver.is_distributed() == True dist.barrier() """ 测试 get_no_sync_context 函数 """ - res = driver.get_model_no_sync_context() + res = self.driver.get_model_no_sync_context() dist.barrier() """ 测试 is_global_zero 函数 """ - driver.is_global_zero() + self.driver.is_global_zero() dist.barrier() """ 测试 unwrap_model 函数 """ - driver.unwrap_model() + self.driver.unwrap_model() dist.barrier() """ 测试 get_local_rank 函数 """ - driver.get_local_rank() + self.driver.get_local_rank() dist.barrier() """ @@ -170,9 +170,9 @@ class TestDeepSpeedDriverFunction: 详细的测试在 test_dist_utils.py 中完成 """ obj = { - "rank": driver.global_rank + "rank": self.driver.global_rank } - obj_list = driver.all_gather(obj, group=None) + obj_list = self.driver.all_gather(obj, group=None) for i, res in enumerate(obj_list): assert res["rank"] == i @@ -180,28 +180,32 @@ class TestDeepSpeedDriverFunction: 测试 broadcast_object 函数 详细的函数在 test_dist_utils.py 中完成 """ - if driver.global_rank == 0: + if self.driver.global_rank == 0: obj = { - "rank": driver.global_rank + "rank": self.driver.global_rank } else: obj = None - res = driver.broadcast_object(obj, src=0) + res = self.driver.broadcast_object(obj, src=0) assert res["rank"] == 0 - if dist.is_initialized(): - dist.destroy_process_group() + # if dist.is_initialized(): + # dist.destroy_process_group() ############################################################################ # # 测试 save 和 load 相关的功能 # ############################################################################ -@pytest.mark.torch +@pytest.mark.deepspeed class TestSaveLoad: """ 测试多卡情况下 save 和 load 相关函数的表现 """ + @classmethod + def setup_class(cls): + # 不在这里 setup 的话会报错 + cls.driver = generate_driver(10, 10, device=[0,1]) def setup_method(self): self.dataset = TorchNormalXYDataset(100) @@ -216,7 +220,8 @@ class TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) + driver1, driver2 = generate_driver(20, 1, train_dataloader=dataloader), \ + generate_driver(20, 1, train_dataloader=dataloader) driver1.save_model(path, only_state_dict) @@ -244,8 +249,8 @@ class TestSaveLoad: finally: rank_zero_rm(path) - if dist.is_initialized(): - dist.destroy_process_group() + # if dist.is_initialized(): + # dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -260,8 +265,6 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ - generate_driver(20, 1, device=device, fp16=False) dataloader = dataloader_with_bucketedbatchsampler( self.dataset, length=[10 for i in range(len(self.dataset))], @@ -270,11 +273,13 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=driver1.world_size, - rank=driver1.global_rank, - pad=True + num_replicas=int(os.getenv("WORLD_SIZE", "1")), + rank=int(os.getenv("RANK", "0")), + pad=True, ) num_consumed_batches = 4 + driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader), \ + generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) already_seen_x_set = set() already_seen_y_set = set() @@ -323,10 +328,6 @@ class TestSaveLoad: assert replaced_loader.batch_sampler.seed == sampler_states["seed"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas - # 3. 检查 fp16 是否被加载 - if fp16: - assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) - # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx start_batch = load_states.pop('batch_idx_in_epoch') @@ -338,6 +339,7 @@ class TestSaveLoad: left_x_batches.update(batch["x"].reshape(-1, ).tolist()) left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + batch = driver1.move_data_to_device(batch) res1 = driver1.model( batch, fastnlp_fn=driver1.model.module.model.evaluate_step, @@ -361,8 +363,8 @@ class TestSaveLoad: finally: rank_zero_rm(path) - if dist.is_initialized(): - dist.destroy_process_group() + # if dist.is_initialized(): + # dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -378,16 +380,16 @@ class TestSaveLoad: num_replicas = len(device) - driver1 = generate_driver(20, 1, device=device, fp16=fp16) - driver2 = generate_driver(20, 1, device=device, fp16=False) - dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=driver1.world_size, - rank=driver1.global_rank, + num_replicas=int(os.getenv("WORLD_SIZE", "1")), + rank=int(os.getenv("RANK", "0")), pad=True ) num_consumed_batches = 4 + + driver1 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader) + driver2 = generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) already_seen_x_set = set() already_seen_y_set = set() @@ -448,6 +450,7 @@ class TestSaveLoad: left_x_batches.update(batch["x"].reshape(-1, ).tolist()) left_y_batches.update(batch["y"].reshape(-1, ).tolist()) + batch = driver1.move_data_to_device(batch) res1 = driver1.model( batch, fastnlp_fn=driver1.model.module.model.evaluate_step, @@ -471,5 +474,5 @@ class TestSaveLoad: finally: rank_zero_rm(path) - if dist.is_initialized(): - dist.destroy_process_group() \ No newline at end of file + # if dist.is_initialized(): + # dist.destroy_process_group() \ No newline at end of file diff --git a/tests/pytest.ini b/tests/pytest.ini index 27076810..e2cac8d9 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -5,4 +5,5 @@ markers = paddledist jittor torchpaddle - torchjittor \ No newline at end of file + torchjittor + deepspeed \ No newline at end of file