|
|
@@ -1,33 +1,30 @@ |
|
|
|
import os |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
import pytest |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver |
|
|
|
from fastNLP.core.samplers import ( |
|
|
|
RandomSampler, |
|
|
|
UnrepeatedSampler, |
|
|
|
BucketedBatchSampler, |
|
|
|
UnrepeatedRandomSampler, |
|
|
|
UnrepeatedSequentialSampler, |
|
|
|
) |
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalXYDataset |
|
|
|
from tests.helpers.utils import magic_argv_env_context |
|
|
|
from fastNLP.envs.distributed import rank_zero_rm |
|
|
|
from fastNLP import logger |
|
|
|
|
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
import torch.distributed as dist |
|
|
|
from torch.utils.data import DataLoader, BatchSampler |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
if _NEED_IMPORT_DEEPSPEED: |
|
|
|
import deepspeed |
|
|
|
|
|
|
|
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): |
|
|
|
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all", train_dataloader=None): |
|
|
|
torch_model = TorchNormalModel_Classification_1(labels, features) |
|
|
|
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) |
|
|
|
device = [torch.device(i) for i in device] |
|
|
@@ -35,7 +32,8 @@ def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_ |
|
|
|
model=torch_model, |
|
|
|
parallel_device=device, |
|
|
|
fp16=fp16, |
|
|
|
output_from_new_proc=output_from_new_proc |
|
|
|
output_from_new_proc=output_from_new_proc, |
|
|
|
train_dataloader=train_dataloader |
|
|
|
) |
|
|
|
driver.set_optimizers(torch_opt) |
|
|
|
driver.setup() |
|
|
@@ -77,33 +75,33 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= |
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 TorchDDPDriver 的一些函数 |
|
|
|
# 测试 TorchDeepSpeedDriver 的一些函数 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@magic_argv_env_context |
|
|
|
def test_multi_drivers(): |
|
|
|
""" |
|
|
|
测试使用了多个 TorchDDPDriver 的情况。 |
|
|
|
""" |
|
|
|
generate_driver(10, 10) |
|
|
|
generate_driver(20, 10) |
|
|
|
# @pytest.mark.deepspeed |
|
|
|
# @magic_argv_env_context |
|
|
|
# def test_multi_drivers(): |
|
|
|
# """ |
|
|
|
# 测试使用了多个 TorchDeepSpeedDriver 的情况。 |
|
|
|
# """ |
|
|
|
# generate_driver(10, 10) |
|
|
|
# generate_driver(20, 10) |
|
|
|
|
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
# 设备设置不同,应该报错 |
|
|
|
generate_driver(20, 3, device=[0,1,2]) |
|
|
|
assert False |
|
|
|
dist.barrier() |
|
|
|
# with pytest.raises(RuntimeError): |
|
|
|
# # 设备设置不同,应该报错 |
|
|
|
# generate_driver(20, 3, device=[0,1,2]) |
|
|
|
# assert False |
|
|
|
# dist.barrier() |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
def test_multi_optimizers(): |
|
|
|
torch_model = TorchNormalModel_Classification_1(10, 10) |
|
|
|
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) |
|
|
|
device = [torch.device(i) for i in device] |
|
|
|
device = [torch.device(i) for i in [0, 1]] |
|
|
|
driver = DeepSpeedDriver( |
|
|
|
model=torch_model, |
|
|
|
parallel_device=device, |
|
|
@@ -112,57 +110,59 @@ def test_multi_optimizers(): |
|
|
|
with pytest.raises(ValueError): |
|
|
|
driver.setup() |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.deepspeed |
|
|
|
class TestDeepSpeedDriverFunction: |
|
|
|
""" |
|
|
|
测试 TorchDeepSpeedDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 |
|
|
|
""" |
|
|
|
@classmethod |
|
|
|
def setup_class(cls): |
|
|
|
cls.driver = generate_driver(10, 10) |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
def test_simple_functions(self): |
|
|
|
""" |
|
|
|
简单测试多个函数 |
|
|
|
""" |
|
|
|
driver = generate_driver(10, 10) |
|
|
|
|
|
|
|
""" |
|
|
|
测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 |
|
|
|
tests/core/utils/test_torch_utils.py中,就不重复测试了 |
|
|
|
""" |
|
|
|
driver.move_data_to_device(torch.rand((32, 64))) |
|
|
|
self.driver.move_data_to_device(torch.rand((32, 64))) |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
|
测试 is_distributed 函数 |
|
|
|
""" |
|
|
|
assert driver.is_distributed() == True |
|
|
|
assert self.driver.is_distributed() == True |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
|
测试 get_no_sync_context 函数 |
|
|
|
""" |
|
|
|
res = driver.get_model_no_sync_context() |
|
|
|
res = self.driver.get_model_no_sync_context() |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
|
测试 is_global_zero 函数 |
|
|
|
""" |
|
|
|
driver.is_global_zero() |
|
|
|
self.driver.is_global_zero() |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
|
测试 unwrap_model 函数 |
|
|
|
""" |
|
|
|
driver.unwrap_model() |
|
|
|
self.driver.unwrap_model() |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
|
测试 get_local_rank 函数 |
|
|
|
""" |
|
|
|
driver.get_local_rank() |
|
|
|
self.driver.get_local_rank() |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
""" |
|
|
@@ -170,9 +170,9 @@ class TestDeepSpeedDriverFunction: |
|
|
|
详细的测试在 test_dist_utils.py 中完成 |
|
|
|
""" |
|
|
|
obj = { |
|
|
|
"rank": driver.global_rank |
|
|
|
"rank": self.driver.global_rank |
|
|
|
} |
|
|
|
obj_list = driver.all_gather(obj, group=None) |
|
|
|
obj_list = self.driver.all_gather(obj, group=None) |
|
|
|
for i, res in enumerate(obj_list): |
|
|
|
assert res["rank"] == i |
|
|
|
|
|
|
@@ -180,28 +180,32 @@ class TestDeepSpeedDriverFunction: |
|
|
|
测试 broadcast_object 函数 |
|
|
|
详细的函数在 test_dist_utils.py 中完成 |
|
|
|
""" |
|
|
|
if driver.global_rank == 0: |
|
|
|
if self.driver.global_rank == 0: |
|
|
|
obj = { |
|
|
|
"rank": driver.global_rank |
|
|
|
"rank": self.driver.global_rank |
|
|
|
} |
|
|
|
else: |
|
|
|
obj = None |
|
|
|
res = driver.broadcast_object(obj, src=0) |
|
|
|
res = self.driver.broadcast_object(obj, src=0) |
|
|
|
assert res["rank"] == 0 |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 save 和 load 相关的功能 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.deepspeed |
|
|
|
class TestSaveLoad: |
|
|
|
""" |
|
|
|
测试多卡情况下 save 和 load 相关函数的表现 |
|
|
|
""" |
|
|
|
@classmethod |
|
|
|
def setup_class(cls): |
|
|
|
# 不在这里 setup 的话会报错 |
|
|
|
cls.driver = generate_driver(10, 10, device=[0,1]) |
|
|
|
|
|
|
|
def setup_method(self): |
|
|
|
self.dataset = TorchNormalXYDataset(100) |
|
|
@@ -216,7 +220,8 @@ class TestSaveLoad: |
|
|
|
path = "model" |
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=2) |
|
|
|
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) |
|
|
|
driver1, driver2 = generate_driver(20, 1, train_dataloader=dataloader), \ |
|
|
|
generate_driver(20, 1, train_dataloader=dataloader) |
|
|
|
|
|
|
|
driver1.save_model(path, only_state_dict) |
|
|
|
|
|
|
@@ -244,8 +249,8 @@ class TestSaveLoad: |
|
|
|
finally: |
|
|
|
rank_zero_rm(path) |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
@@ -260,8 +265,6 @@ class TestSaveLoad: |
|
|
|
path = "model.ckp" |
|
|
|
num_replicas = len(device) |
|
|
|
|
|
|
|
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ |
|
|
|
generate_driver(20, 1, device=device, fp16=False) |
|
|
|
dataloader = dataloader_with_bucketedbatchsampler( |
|
|
|
self.dataset, |
|
|
|
length=[10 for i in range(len(self.dataset))], |
|
|
@@ -270,11 +273,13 @@ class TestSaveLoad: |
|
|
|
drop_last=False |
|
|
|
) |
|
|
|
dataloader.batch_sampler.set_distributed( |
|
|
|
num_replicas=driver1.world_size, |
|
|
|
rank=driver1.global_rank, |
|
|
|
pad=True |
|
|
|
num_replicas=int(os.getenv("WORLD_SIZE", "1")), |
|
|
|
rank=int(os.getenv("RANK", "0")), |
|
|
|
pad=True, |
|
|
|
) |
|
|
|
num_consumed_batches = 4 |
|
|
|
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader), \ |
|
|
|
generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) |
|
|
|
|
|
|
|
already_seen_x_set = set() |
|
|
|
already_seen_y_set = set() |
|
|
@@ -323,10 +328,6 @@ class TestSaveLoad: |
|
|
|
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] |
|
|
|
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas |
|
|
|
|
|
|
|
# 3. 检查 fp16 是否被加载 |
|
|
|
if fp16: |
|
|
|
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) |
|
|
|
|
|
|
|
# 4. 检查 model 的参数是否正确 |
|
|
|
# 5. 检查 batch_idx |
|
|
|
start_batch = load_states.pop('batch_idx_in_epoch') |
|
|
@@ -338,6 +339,7 @@ class TestSaveLoad: |
|
|
|
|
|
|
|
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) |
|
|
|
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) |
|
|
|
batch = driver1.move_data_to_device(batch) |
|
|
|
res1 = driver1.model( |
|
|
|
batch, |
|
|
|
fastnlp_fn=driver1.model.module.model.evaluate_step, |
|
|
@@ -361,8 +363,8 @@ class TestSaveLoad: |
|
|
|
finally: |
|
|
|
rank_zero_rm(path) |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
@@ -378,16 +380,16 @@ class TestSaveLoad: |
|
|
|
|
|
|
|
num_replicas = len(device) |
|
|
|
|
|
|
|
driver1 = generate_driver(20, 1, device=device, fp16=fp16) |
|
|
|
driver2 = generate_driver(20, 1, device=device, fp16=False) |
|
|
|
|
|
|
|
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) |
|
|
|
dataloader.batch_sampler.sampler.set_distributed( |
|
|
|
num_replicas=driver1.world_size, |
|
|
|
rank=driver1.global_rank, |
|
|
|
num_replicas=int(os.getenv("WORLD_SIZE", "1")), |
|
|
|
rank=int(os.getenv("RANK", "0")), |
|
|
|
pad=True |
|
|
|
) |
|
|
|
num_consumed_batches = 4 |
|
|
|
|
|
|
|
driver1 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader) |
|
|
|
driver2 = generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader) |
|
|
|
|
|
|
|
already_seen_x_set = set() |
|
|
|
already_seen_y_set = set() |
|
|
@@ -448,6 +450,7 @@ class TestSaveLoad: |
|
|
|
|
|
|
|
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) |
|
|
|
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) |
|
|
|
batch = driver1.move_data_to_device(batch) |
|
|
|
res1 = driver1.model( |
|
|
|
batch, |
|
|
|
fastnlp_fn=driver1.model.module.model.evaluate_step, |
|
|
@@ -471,5 +474,5 @@ class TestSaveLoad: |
|
|
|
finally: |
|
|
|
rank_zero_rm(path) |
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
# if dist.is_initialized(): |
|
|
|
# dist.destroy_process_group() |