
DeepSpeed save/load functionality

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 7023ea550c
4 changed files with 106 additions and 75 deletions
  1. fastNLP/core/drivers/torch_driver/deepspeed.py (+13, -9)
  2. fastNLP/core/drivers/torch_driver/utils.py (+24, -1)
  3. tests/core/drivers/torch_driver/test_deepspeed.py (+67, -64)
  4. tests/pytest.ini (+2, -1)

fastNLP/core/drivers/torch_driver/deepspeed.py (+13, -9)

@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, List
 from .torch_driver import TorchDriver
 from .ddp import TorchDDPDriver
-from .utils import _create_default_config, _DDPWrappingModel
+from .utils import _create_default_config, _DeepSpeedWrappingModel
 from fastNLP.core.utils import nullcontext
 from fastNLP.core.log import logger
 from fastNLP.envs import(
@@ -14,6 +14,7 @@ from fastNLP.envs import(
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
 
 if _NEED_IMPORT_TORCH:
+    import pytorch_lightning
     import torch
     import torch.distributed as dist
@@ -35,8 +36,8 @@ class DeepSpeedDriver(TorchDDPDriver):
         strategy= "deepspeed",
         **kwargs
     ):
-        assert _NEED_IMPORT_DEEPSPEED, "deepspeed is not imported."
-        assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user."
+        assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."
+        # assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user."
         TorchDriver.__init__(self, model=model, fp16=False, **kwargs)
         self.fp16 = fp16
 
@@ -88,7 +89,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         # Get the batch_size in order to set the train_micro_batch_size_per_gpu parameter
         train_dl = kwargs.get("train_dataloader", None)
         if train_dl is not None:
-            self.train_micro_batch_size = self.get_dataloader_args(train_dl)
+            self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size
         else:
             logger.warn("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`"
                         "to 1 for deepspeed configuration.")
@@ -166,7 +167,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         # Set up deepspeed
         if not isinstance(self.model, deepspeed.DeepSpeedEngine):
-            model=_DDPWrappingModel(self.model)
+            model=_DeepSpeedWrappingModel(self.model, self.fp16)
             model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
                 model=model,
@@ -279,7 +280,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         :return:
         """
         # The deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
-        if self.zero_stage_3:
+        if self.stage_3:
             logger.rank_zero_warning(
                 "When saving the DeepSpeed Stage 3 checkpoint, "
                 "each worker will save a shard of the checkpoint within a directory. "
@@ -310,7 +311,8 @@ class DeepSpeedDriver(TorchDDPDriver):
     def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         # The deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
         # 1. Save the sampler state
-        sampler_state_dict = self.get_sampler_state_dict()
+        num_consumed_batches = states.pop('num_consumed_batches')
+        states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)
 
         # 2. Save the model state;
         if not should_save_model:
@@ -318,7 +320,7 @@ class DeepSpeedDriver(TorchDDPDriver):
                                      "so we will still save the model for you.")
 
         self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME),
-                                   client_state=sampler_state_dict)
+                                   client_state=states)
 
     def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         # 1. Load the model state;
@@ -330,7 +332,9 @@ class DeepSpeedDriver(TorchDDPDriver):
             raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}")
 
         # 2. Restore the sampler state
-        states = self.load_sampler_state_dict(states)
+        sampler_states = states.pop('sampler_states')
+        states_ret = self.load_sampler_state(dataloader, sampler_states)
+        states.update(states_ret)
 
         return states
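
The save path now routes the sampler state through DeepSpeed's own checkpoint mechanism: it is packed into `states` before the engine call, travels as `client_state`, and is recovered from the same dict on load. A minimal sketch of that round trip (not the fastNLP implementation; `engine` is assumed to be an already-initialized `deepspeed.DeepSpeedEngine`, and `sampler_state` any picklable dict):

```python
def save_with_sampler_state(engine, folder, states, sampler_state):
    # DeepSpeed persists `client_state` alongside the model/optimizer shards,
    # so the sampler state can ride along in the same checkpoint.
    states["sampler_states"] = sampler_state
    engine.save_checkpoint(str(folder), client_state=states)

def load_with_sampler_state(engine, folder):
    # load_checkpoint returns (load_path, client_state); load_path is None
    # when nothing could be restored.
    load_path, states = engine.load_checkpoint(str(folder))
    if load_path is None:
        raise RuntimeError(f"Failed to load checkpoint from path: {folder}")
    sampler_state = states.pop("sampler_states")
    return states, sampler_state
```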



fastNLP/core/drivers/torch_driver/utils.py (+24, -1)

@@ -15,7 +15,7 @@ from fastNLP.envs import (
     FASTNLP_GLOBAL_SEED,
 )
 from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
-from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils import auto_param_call, apply_to_collection
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_TORCH:
@@ -107,6 +107,29 @@ class _DDPWrappingModel(Module):
         else:
             return fn(batch)
 
+class _DeepSpeedWrappingModel(_DDPWrappingModel):
+    """
+    Inherits ``_DDPWrappingModel``; the difference is that float data is converted to float16 before the forward pass.
+    """
+
+    def __init__(self, model: Module, fp16):
+        super(_DeepSpeedWrappingModel, self).__init__(model)
+        self.fp16 = fp16
+
+    def forward(self, batch, **kwargs):
+        if self.fp16:
+            batch = self._move_float_tensors_to_half(batch)
+
+        return super().forward(batch, **kwargs)
+
+    @staticmethod
+    def batch_to(data):
+        return data.half()
+
+    def _move_float_tensors_to_half(self, batch: Any):
+        batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
+        return batch
+
 
 class DummyGradScaler:
     """


tests/core/drivers/torch_driver/test_deepspeed.py (+67, -64)

@@ -1,33 +1,30 @@
 import os
+from pathlib import Path
 
 import pytest
-from pathlib import Path
 
 from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver
 from fastNLP.core.samplers import (
     RandomSampler,
     UnrepeatedSampler,
     BucketedBatchSampler,
     UnrepeatedRandomSampler,
     UnrepeatedSequentialSampler,
 )
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
-from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
+from tests.helpers.datasets.torch_data import TorchNormalXYDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from fastNLP import logger
 
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
 
 if _NEED_IMPORT_TORCH:
     import torch
     import torch.distributed as dist
-    from torch.utils.data import DataLoader, BatchSampler
+    from torch.utils.data import DataLoader
 if _NEED_IMPORT_DEEPSPEED:
     import deepspeed
 
-def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"):
+def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all", train_dataloader=None):
     torch_model = TorchNormalModel_Classification_1(labels, features)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
     device = [torch.device(i) for i in device]
@@ -35,7 +32,8 @@ def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_
         model=torch_model,
         parallel_device=device,
         fp16=fp16,
-        output_from_new_proc=output_from_new_proc
+        output_from_new_proc=output_from_new_proc,
+        train_dataloader=train_dataloader
     )
     driver.set_optimizers(torch_opt)
     driver.setup()
@@ -77,33 +75,33 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
 
 ############################################################################
 #
-# Test some functions of TorchDDPDriver
+# Test some functions of TorchDeepSpeedDriver
 #
 ############################################################################
 
-@pytest.mark.torch
-@magic_argv_env_context
-def test_multi_drivers():
-    """
-    Test the case where multiple TorchDDPDriver instances are used.
-    """
-    generate_driver(10, 10)
-    generate_driver(20, 10)
+# @pytest.mark.deepspeed
+# @magic_argv_env_context
+# def test_multi_drivers():
+#     """
+#     Test the case where multiple TorchDeepSpeedDriver instances are used.
+#     """
+#     generate_driver(10, 10)
+#     generate_driver(20, 10)
-    with pytest.raises(RuntimeError):
-        # Different device settings should raise an error
-        generate_driver(20, 3, device=[0,1,2])
-        assert False
-    dist.barrier()
+    # with pytest.raises(RuntimeError):
+    #     # Different device settings should raise an error
+    #     generate_driver(20, 3, device=[0,1,2])
+    #     assert False
+    # dist.barrier()
 
-    if dist.is_initialized():
-        dist.destroy_process_group()
+    # if dist.is_initialized():
+    #     dist.destroy_process_group()
 
 @magic_argv_env_context
 def test_multi_optimizers():
     torch_model = TorchNormalModel_Classification_1(10, 10)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
-    device = [torch.device(i) for i in device]
+    device = [torch.device(i) for i in [0, 1]]
     driver = DeepSpeedDriver(
         model=torch_model,
         parallel_device=device,
@@ -112,57 +110,59 @@ def test_multi_optimizers():
     with pytest.raises(ValueError):
         driver.setup()
 
-    if dist.is_initialized():
-        dist.destroy_process_group()
+    # if dist.is_initialized():
+    #     dist.destroy_process_group()
 
-@pytest.mark.torch
+@pytest.mark.deepspeed
 class TestDeepSpeedDriverFunction:
     """
    Test class for some simple functions of TorchDeepSpeedDriver; mainly checks that they run and that no import errors occur.
     """
+    @classmethod
+    def setup_class(cls):
+        cls.driver = generate_driver(10, 10)
 
     @magic_argv_env_context
     def test_simple_functions(self):
         """
        Simple tests of several functions
         """
-        driver = generate_driver(10, 10)
-
         """
        Test the move_data_to_device function. It only calls torch_move_data_to_device; the test cases are in
        tests/core/utils/test_torch_utils.py, so they are not repeated here.
         """
-        driver.move_data_to_device(torch.rand((32, 64)))
+        self.driver.move_data_to_device(torch.rand((32, 64)))
         dist.barrier()
 
         """
        Test the is_distributed function
         """
-        assert driver.is_distributed() == True
+        assert self.driver.is_distributed() == True
         dist.barrier()
 
         """
        Test the get_no_sync_context function
         """
-        res = driver.get_model_no_sync_context()
+        res = self.driver.get_model_no_sync_context()
         dist.barrier()
 
         """
        Test the is_global_zero function
         """
-        driver.is_global_zero()
+        self.driver.is_global_zero()
         dist.barrier()
 
         """
        Test the unwrap_model function
         """
-        driver.unwrap_model()
+        self.driver.unwrap_model()
         dist.barrier()
 
         """
        Test the get_local_rank function
         """
-        driver.get_local_rank()
+        self.driver.get_local_rank()
         dist.barrier()
 
         """
@@ -170,9 +170,9 @@ class TestDeepSpeedDriverFunction:
        Detailed tests are completed in test_dist_utils.py
         """
         obj = {
-            "rank": driver.global_rank
+            "rank": self.driver.global_rank
         }
-        obj_list = driver.all_gather(obj, group=None)
+        obj_list = self.driver.all_gather(obj, group=None)
         for i, res in enumerate(obj_list):
             assert res["rank"] == i
@@ -180,28 +180,32 @@ class TestDeepSpeedDriverFunction:
        Test the broadcast_object function
        Detailed tests are completed in test_dist_utils.py
         """
-        if driver.global_rank == 0:
+        if self.driver.global_rank == 0:
             obj = {
-                "rank": driver.global_rank
+                "rank": self.driver.global_rank
             }
         else:
             obj = None
-        res = driver.broadcast_object(obj, src=0)
+        res = self.driver.broadcast_object(obj, src=0)
         assert res["rank"] == 0
 
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
 
 ############################################################################
 #
 # Test save- and load-related functionality
 #
 ############################################################################
-@pytest.mark.torch
+@pytest.mark.deepspeed
 class TestSaveLoad:
     """
    Test the behavior of save- and load-related functions in the multi-GPU case
     """
+    @classmethod
+    def setup_class(cls):
+        # Setting up here is required; otherwise an error is raised
+        cls.driver = generate_driver(10, 10, device=[0,1])
 
     def setup_method(self):
         self.dataset = TorchNormalXYDataset(100)
@@ -216,7 +220,8 @@ class TestSaveLoad:
         path = "model"
 
         dataloader = DataLoader(self.dataset, batch_size=2)
-        driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1)
+        driver1, driver2 = generate_driver(20, 1, train_dataloader=dataloader), \
+            generate_driver(20, 1, train_dataloader=dataloader)
 
         driver1.save_model(path, only_state_dict)
 
@@ -244,8 +249,8 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)
 
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -260,8 +265,6 @@ class TestSaveLoad:
         path = "model.ckp"
         num_replicas = len(device)
 
-        driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \
-            generate_driver(20, 1, device=device, fp16=False)
         dataloader = dataloader_with_bucketedbatchsampler(
             self.dataset,
             length=[10 for i in range(len(self.dataset))],
@@ -270,11 +273,13 @@ class TestSaveLoad:
             drop_last=False
         )
         dataloader.batch_sampler.set_distributed(
-            num_replicas=driver1.world_size,
-            rank=driver1.global_rank,
-            pad=True
+            num_replicas=int(os.getenv("WORLD_SIZE", "1")),
+            rank=int(os.getenv("RANK", "0")),
+            pad=True,
         )
         num_consumed_batches = 4
+        driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader), \
+            generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader)
 
         already_seen_x_set = set()
         already_seen_y_set = set()
@@ -323,10 +328,6 @@ class TestSaveLoad:
             assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
             assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas
 
-            # 3. Check whether fp16 was loaded
-            if fp16:
-                assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
-
             # 4. Check whether the model parameters are correct
             # 5. Check batch_idx
             start_batch = load_states.pop('batch_idx_in_epoch')
@@ -338,6 +339,7 @@ class TestSaveLoad:
 
                 left_x_batches.update(batch["x"].reshape(-1, ).tolist())
                 left_y_batches.update(batch["y"].reshape(-1, ).tolist())
+                batch = driver1.move_data_to_device(batch)
                 res1 = driver1.model(
                     batch,
                     fastnlp_fn=driver1.model.module.model.evaluate_step,
@@ -361,8 +363,8 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)
 
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -378,16 +380,16 @@ class TestSaveLoad:
 
         num_replicas = len(device)
 
-        driver1 = generate_driver(20, 1, device=device, fp16=fp16)
-        driver2 = generate_driver(20, 1, device=device, fp16=False)
-
         dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=driver1.world_size,
-            rank=driver1.global_rank,
+            num_replicas=int(os.getenv("WORLD_SIZE", "1")),
+            rank=int(os.getenv("RANK", "0")),
             pad=True
         )
         num_consumed_batches = 4
+        driver1 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader)
+        driver2 = generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader)
 
         already_seen_x_set = set()
         already_seen_y_set = set()
@@ -448,6 +450,7 @@ class TestSaveLoad:
 
                 left_x_batches.update(batch["x"].reshape(-1, ).tolist())
                 left_y_batches.update(batch["y"].reshape(-1, ).tolist())
+                batch = driver1.move_data_to_device(batch)
                 res1 = driver1.model(
                     batch,
                     fastnlp_fn=driver1.model.module.model.evaluate_step,
@@ -471,5 +474,5 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)
 
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
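
A pattern worth noting in the reworked tests: the driver now receives `train_dataloader` at construction time, so the dataloader and its distributed sampler must be prepared before any driver exists, and world size and rank are therefore read from the launcher's environment variables instead of from a driver instance. A condensed sketch, assuming torchrun-style `WORLD_SIZE`/`RANK` variables and the helpers defined in this test module:

```python
import os

# Read distributed info from the environment, since no driver exists yet.
num_replicas = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))

dataloader = dataloader_with_randomsampler(dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
    num_replicas=num_replicas,
    rank=rank,
    pad=True,
)

# Only now are the drivers created, with the prepared dataloader attached.
driver1 = generate_driver(20, 1, fp16=False, train_dataloader=dataloader)
```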

tests/pytest.ini (+2, -1)

@@ -5,4 +5,5 @@ markers =
     paddledist
     jittor
     torchpaddle
-    torchjittor
+    torchjittor
+    deepspeed
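
Registering the `deepspeed` marker lets the new tests be selected or deselected from the command line in the usual pytest way (e.g. `pytest -m deepspeed tests/core/drivers/torch_driver/test_deepspeed.py`); without the registration, runs using `--strict-markers` would reject the `@pytest.mark.deepspeed` decorator.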
