diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index e53f4066..9f35cf2a 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -182,7 +182,13 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
         if key not in init_params and key != 'self':
             init_params[key] = value
 
-    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
+    # If a dataloader init parameter is not at its default value, record it so it can be re-applied on re-initialization;
+    non_default_params = {name for name, p in init_params.items() if
+                          name in instance_attrs and p.default != instance_attrs[name]}
+    # add `dataset` as it might have been replaced with `*args`
+    non_default_params.add("dataset")
+
+    reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
     reconstruct_args.update({
         "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
         "persistent_workers": dataloader._persistent_workers,
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index 93d3e832..4421e4b1 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -13,6 +13,8 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
+from fastNLP import prepare_paddle_dataloader
+from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 if _NEED_IMPORT_PADDLE:
     import paddle
@@ -814,4 +816,112 @@ class TestSaveLoad:
                     assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
         finally:
-            rank_zero_rm(path)
\ No newline at end of file
+            rank_zero_rm(path)
+
+
+@pytest.mark.paddle
+@magic_argv_env_context
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
+    try:
+        # check that set_dist_repro_dataloader does not modify the passed-in arguments
+        num_samples = 200
+        dataset = PaddleNormalXYDataset(num_samples)
+        dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
+        model = PaddleNormalModel_Classification_1(10, 32)
+        device = [0, 1]
+        driver = PaddleFleetDriver(model, parallel_device=device)
+        driver.setup()
+        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
+
+        data = []
+        flags = []
+        for batch in dl:
+            flags.append(batch['x'].shape[0] == batch_size)
+            data.extend(batch['x'].reshape(-1).tolist())
+
+        _num_samples = num_samples // 2
+
+        if drop_last and _num_samples % batch_size != 0:
+            assert len(data) != _num_samples
+            assert all(flags) == True
+        elif _num_samples % batch_size != 0:
+            assert flags[-1] is False
+        else:
+            assert len(data) == _num_samples
+
+        if not shuffle:
+            for i in range(1, len(data) - 1):
+                assert data[i] > data[i - 1]
+        else:
+            flags = []
+            for i in range(1, len(data) - 1):
+                flags.append(data[i] > data[i - 1])
+            assert all(flags) is False
+        datas = fastnlp_paddle_all_gather(data)
+        if drop_last:
+            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
+        else:
+            assert len(set(datas[0] + datas[1])) == num_samples
+    finally:
+        if dist.is_initialized():
+            dist.barrier()
+            dist.destroy_process_group()
+
+
+@pytest.mark.paddle
+@magic_argv_env_context
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True):
+    try:
+        # check that set_dist_repro_dataloader does not modify the passed-in arguments
+        num_samples = 200
+        num_device = 2
+        dataset = PaddleNormalXYDataset(num_samples)
+        sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
+                                       shuffle=shuffle, num_batch_per_bucket=2)
+        dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
+        model = PaddleNormalModel_Classification_1(10, 32)
+        device = [0, 1]
+        driver = PaddleFleetDriver(model, parallel_device=device)
+        driver.setup()
+        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
+
+        data = []
+        flags = []
+        for batch in dl:
+            d = batch['x'].reshape(-1).tolist()
+            diff = max(d) - min(d)
+            assert diff <= batch_size * 2 * num_device
+            data.extend(d)
+            flags.append(len(d) == batch_size)
+        _num_samples = num_samples // num_device
+
+        if drop_last and _num_samples % batch_size != 0:
+            assert len(data) != _num_samples
+            assert all(flags) == True
+        elif _num_samples % batch_size != 0:
+            assert flags[-1] is False
+        else:
+            assert len(data) == _num_samples
+
+        if not shuffle:
+            for i in range(1, len(data)):
+                assert data[i] > data[i - 1]
+        else:
+            flags = []
+            for i in range(1, len(data)):
+                flags.append(data[i] > data[i - 1])
+            assert all(flags) is False
+        datas = fastnlp_paddle_all_gather(data)
+        if drop_last:
+            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
+        else:
+            assert len(set(datas[0] + datas[1])) == num_samples
+    finally:
+        if dist.is_initialized():
+            dist.barrier()
+            dist.destroy_process_group()
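A note on the `utils.py` hunk above: the non-default detection can be exercised in isolation. The sketch below is a minimal standalone illustration, where a toy `Loader` class stands in for paddle's `DataLoader` (the class and values here are illustrative, not fastNLP API), showing how comparing each `inspect.Parameter.default` against the live instance attributes keeps only the arguments a caller actually overrode:

```python
import inspect

class Loader:
    """Toy stand-in for a DataLoader-like class; not paddle's API."""
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

dl = Loader(dataset=[1, 2, 3], batch_size=32)

# Signature defaults vs. live attribute values -- the same comparison the hunk adds.
init_params = dict(inspect.signature(Loader.__init__).parameters)
instance_attrs = vars(dl)

non_default_params = {name for name, p in init_params.items()
                      if name in instance_attrs and p.default != instance_attrs[name]}
non_default_params.add("dataset")  # `dataset` may have been passed positionally

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
print(reconstruct_args)  # -> {'dataset': [1, 2, 3], 'batch_size': 32}
```

Recording only the non-default arguments (`shuffle` and `num_workers` are skipped here) keeps the reconstructed dataloader from pinning values that `replace_batch_sampler` must override anyway, such as `shuffle` and `batch_size`.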
diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py
--- a/tests/core/drivers/paddle_driver/test_single_device.py
+++ b/tests/core/drivers/paddle_driver/test_single_device.py
@@ ... @@
+@pytest.mark.paddle
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+@pytest.mark.parametrize("reproducible", ([True, False]))
+def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible):
+    # check that set_dist_repro_dataloader does not modify the passed-in arguments
+    num_samples = 200
+    dataset = PaddleNormalXYDataset(num_samples)
+    sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
+                                   shuffle=shuffle, num_batch_per_bucket=2)
+    dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
+    model = PaddleNormalModel_Classification_1(1, 2)
+    driver = PaddleSingleDriver(model, device="cpu")
+    dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible)
+
+    data = []
+    flags = []
+    for batch in dl:
+        d = batch['x'].reshape(-1).tolist()
+        diff = max(d) - min(d)
+        assert diff <= batch_size * 2
+        data.extend(d)
+        flags.append(len(d) == batch_size)
+
+    if drop_last and num_samples % batch_size != 0:
+        assert len(data) != num_samples
+        assert all(flags) == True
+    elif num_samples % batch_size != 0:
+        assert flags[-1] is False
+    else:
+        assert len(data) == num_samples
+
+    if not shuffle:
+        for i in range(1, len(data)):
+            assert data[i] > data[i - 1]
+    else:
+        flags = []
+        for i in range(1, len(data) - 1):
+            flags.append(data[i] > data[i - 1])
+        assert all(flags) is False
diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py
--- a/tests/core/drivers/torch_driver/test_ddp.py
+++ b/tests/core/drivers/torch_driver/test_ddp.py
@@ ... @@
+@pytest.mark.torch
+@magic_argv_env_context
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
+    try:
+        # check that set_dist_repro_dataloader does not modify the passed-in arguments
+        num_samples = 200
+        dataset = TorchNormalXYDataset(num_samples)
+        dl = prepare_torch_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
+        model = TorchNormalModel_Classification_1(10, 32)
+        device = [torch.device(i) for i in [0, 1]]
+        driver = TorchDDPDriver(model, parallel_device=device)
+        driver.setup()
+        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
+
+        data = []
+        flags = []
+        for batch in dl:
+            flags.append(batch['x'].shape[0] == batch_size)
+            data.extend(batch['x'].reshape(-1).tolist())
+
+        _num_samples = num_samples // 2
+
+        if drop_last and _num_samples % batch_size != 0:
+            assert len(data) != _num_samples
+            assert all(flags) == True
+        elif _num_samples % batch_size != 0:
+            assert flags[-1] is False
+        else:
+            assert len(data) == _num_samples
+
+        if not shuffle:
+            for i in range(1, len(data) - 1):
+                assert data[i] > data[i - 1]
+        else:
+            flags = []
+            for i in range(1, len(data) - 1):
+                flags.append(data[i] > data[i - 1])
+            assert all(flags) is False
+        datas = fastnlp_torch_all_gather(data)
+        if drop_last:
+            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
+        else:
+            assert len(set(datas[0] + datas[1])) == num_samples
+    finally:
+        if dist.is_initialized():
+            dist.barrier()
+            dist.destroy_process_group()
+
+
+@pytest.mark.torch
+@magic_argv_env_context
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True):
+    try:
+        # check that set_dist_repro_dataloader does not modify the passed-in arguments
+        num_samples = 200
+        num_device = 2
+        dataset = TorchNormalXYDataset(num_samples)
+        sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
+                                       shuffle=shuffle, num_batch_per_bucket=2)
+        dl = prepare_torch_dataloader(dataset, batch_sampler=sampler)
+        model = TorchNormalModel_Classification_1(10, 32)
+        device = [torch.device(i) for i in [0, 1]]
+        driver = TorchDDPDriver(model, parallel_device=device)
+        driver.setup()
+        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)
+
+        data = []
+        flags = []
+        for batch in dl:
+            d = batch['x'].reshape(-1).tolist()
+            diff = max(d) - min(d)
+            assert diff <= batch_size * 2 * num_device
+            data.extend(d)
+            flags.append(len(d) == batch_size)
+        _num_samples = num_samples // num_device
+
+        if drop_last and _num_samples % batch_size != 0:
+            assert len(data) != _num_samples
+            assert all(flags) == True
+        elif _num_samples % batch_size != 0:
+            assert flags[-1] is False
+        else:
+            assert len(data) == _num_samples
+
+        if not shuffle:
+            for i in range(1, len(data)):
+                assert data[i] > data[i - 1]
+        else:
+            flags = []
+            for i in range(1, len(data)):
+                flags.append(data[i] > data[i - 1])
+            assert all(flags) is False
+        datas = fastnlp_torch_all_gather(data)
+        if drop_last:
+            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
+        else:
+            assert len(set(datas[0] + datas[1])) == num_samples
+    finally:
+        if dist.is_initialized():
+            dist.barrier()
+            dist.destroy_process_group()
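The size assertions that close the distributed tests above follow from simple remainder arithmetic. A plain-Python check of the expected union size under `drop_last` (assuming two ranks, each seeing `num_samples // 2` samples, as in the tests; no paddle/torch needed):

```python
# Expected union size across 2 ranks when drop_last=True, as asserted above.
num_samples, num_device = 200, 2
_num_samples = num_samples // num_device   # samples per rank

for batch_size in (1, 3, 16, 17):
    dropped_per_rank = _num_samples % batch_size   # tail samples each rank discards
    expected = num_samples - dropped_per_rank * num_device
    print(f"batch_size={batch_size}: union covers {expected} samples")
# batch_size=17 -> each rank drops 100 % 17 = 15 tail samples, union = 200 - 30 = 170
```

This is exactly `num_samples - _num_samples % batch_size * 2` from the asserts (`%` binds tighter than `*`), and for `drop_last=False` the union must cover all `num_samples`.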
diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py
--- a/tests/core/drivers/torch_driver/test_single_device.py
+++ b/tests/core/drivers/torch_driver/test_single_device.py
@@ ... @@
+@pytest.mark.torch
+@pytest.mark.parametrize("shuffle", ([True, False]))
+@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
+@pytest.mark.parametrize("drop_last", ([True, False]))
+@pytest.mark.parametrize("reproducible", ([True, False]))
+def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible):
+    # check that set_dist_repro_dataloader does not modify the passed-in arguments
+    num_samples = 100
+    dataset = TorchNormalXYDataset(num_samples)
+    sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
+                                   shuffle=shuffle, num_batch_per_bucket=2)
+    dl = prepare_torch_dataloader(dataset, batch_sampler=sampler)
+    model = TorchNormalModel_Classification_1(10, 32)
+    driver = TorchSingleDriver(model, device="cpu")
+    dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible)
+
+    data = []
+    flags = []
+    for batch in dl:
+        d = batch['x'].reshape(-1).tolist()
+        diff = max(d) - min(d)
+        assert diff <= batch_size * 2
+        data.extend(d)
+        flags.append(len(d) == batch_size)
+
+    if drop_last and num_samples % batch_size != 0:
+        assert len(data) != num_samples
+        assert all(flags) == True
+    elif num_samples % batch_size != 0:
+        assert flags[-1] is False
+    else:
+        assert len(data) == num_samples
+
+    if not shuffle:
+        for i in range(1, len(data)):
+            assert data[i] > data[i - 1]
+    else:
+        flags = []
+        for i in range(1, len(data) - 1):
+            flags.append(data[i] > data[i - 1])
+        assert all(flags) is False
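The `max(d) - min(d)` assertions in the batch-sampler tests rest on the bucketing behaviour: length-sorted indices are grouped into buckets of `batch_size * num_batch_per_bucket` samples and shuffled only within a bucket, so one batch can never span more than a bucket's width. A toy model of that bucketing (a sketch of the idea, not fastNLP's `BucketedBatchSampler` implementation) demonstrates the bound:

```python
import random

def toy_bucketed_batches(indices, batch_size, num_batch_per_bucket, seed=0):
    # Toy model: split the (already sorted) indices into fixed-width buckets
    # and shuffle only within each bucket before slicing off batches.
    rng = random.Random(seed)
    bucket_size = batch_size * num_batch_per_bucket
    for start in range(0, len(indices), bucket_size):
        bucket = indices[start:start + bucket_size]
        rng.shuffle(bucket)
        for b in range(0, len(bucket), batch_size):
            yield bucket[b:b + batch_size]

batch_size, num_batch_per_bucket = 16, 2
for batch in toy_bucketed_batches(list(range(200)), batch_size, num_batch_per_bucket):
    # Every batch is drawn from one bucket, so its index spread is bounded
    # by the bucket width -- the property the tests assert via max - min.
    assert max(batch) - min(batch) < batch_size * num_batch_per_bucket
```

In the distributed variants the bound is additionally multiplied by `num_device`, since the bucket's indices are split round-robin across replicas before batching.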