@@ -182,7 +182,13 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler | |||||
if key not in init_params and key != 'self': | if key not in init_params and key != 'self': | ||||
init_params[key] = value | init_params[key] = value | ||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} | |||||
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
reconstruct_args.update({ | reconstruct_args.update({ | ||||
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, | ||||
"persistent_workers": dataloader._persistent_workers, | "persistent_workers": dataloader._persistent_workers, | ||||
@@ -13,6 +13,8 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP import prepare_paddle_dataloader | |||||
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -814,4 +816,112 @@ class TestSaveLoad: | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | ||||
finally: | finally: | ||||
rank_zero_rm(path) | |||||
rank_zero_rm(path) | |||||
@pytest.mark.paddle
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
    """Check that ``set_dist_repro_dataloader`` keeps the dataloader's settings.

    A dataloader is built with the parametrized ``shuffle`` / ``batch_size`` /
    ``drop_last`` values, passed through
    ``PaddleFleetDriver.set_dist_repro_dataloader(dist='dist')`` on devices
    ``[0, 1]``, and then iterated. The test verifies batch sizes, the number
    of samples seen on this rank, the ordering of the samples, and the number
    of distinct samples seen across both ranks.

    :param shuffle: whether the original dataloader shuffles.
    :param batch_size: batch size of the original dataloader.
    :param drop_last: whether the original dataloader drops the last
        incomplete batch.
    :param reproducible: forwarded to ``set_dist_repro_dataloader``.
    """
    try:
        # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
        num_samples = 200
        dataset = PaddleNormalXYDataset(num_samples)
        dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
        model = PaddleNormalModel_Classification_1(10, 32)
        device = [0, 1]
        driver = PaddleFleetDriver(model, parallel_device=device)
        driver.setup()
        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)

        data = []
        flags = []  # one entry per batch: True when the batch is full-sized
        for batch in dl:
            flags.append(batch['x'].shape[0] == batch_size)
            data.extend(batch['x'].reshape(-1).tolist())

        # Expected number of samples on this rank (two devices are used above).
        _num_samples = num_samples // 2
        if drop_last and _num_samples % batch_size != 0:
            # The incomplete tail batch is dropped, so every batch is full.
            assert len(data) != _num_samples
            assert all(flags)
        elif _num_samples % batch_size != 0:
            # The tail batch is kept and must be smaller than batch_size.
            assert flags[-1] is False
        else:
            assert len(data) == _num_samples

        # NOTE(review): the final adjacent pair is not compared (range stops at
        # len(data) - 1) -- presumably because the last sample may be padding
        # added to even out the ranks; confirm.
        if not shuffle:
            for i in range(1, len(data) - 1):
                assert data[i] > data[i - 1]
        else:
            flags = []
            for i in range(1, len(data) - 1):
                flags.append(data[i] > data[i - 1])
            # A shuffled stream must not be fully ascending.
            assert not all(flags)

        # datas is indexed per rank below; presumably one list per rank -- confirm.
        datas = fastnlp_paddle_all_gather(data)
        if drop_last:
            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
        else:
            assert len(set(datas[0] + datas[1])) == num_samples
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
@pytest.mark.paddle
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True):
    """Check ``set_dist_repro_dataloader`` with a ``BucketedBatchSampler``.

    A dataloader driven by a ``BucketedBatchSampler`` is passed through
    ``PaddleFleetDriver.set_dist_repro_dataloader(dist='dist')`` on devices
    ``[0, 1]`` and iterated. The test verifies the value spread inside each
    batch, batch sizes, the number of samples seen on this rank, the ordering
    of the samples, and the number of distinct samples across both ranks.

    :param shuffle: whether the batch sampler shuffles.
    :param batch_size: batch size of the batch sampler.
    :param drop_last: whether the batch sampler drops the last incomplete batch.
    :param reproducible: forwarded to ``set_dist_repro_dataloader``.
    """
    try:
        # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
        num_samples = 200
        num_device = 2
        dataset = PaddleNormalXYDataset(num_samples)
        sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
                                       shuffle=shuffle, num_batch_per_bucket=2)
        dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
        model = PaddleNormalModel_Classification_1(10, 32)
        device = [0, 1]
        driver = PaddleFleetDriver(model, parallel_device=device)
        driver.setup()
        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)

        data = []
        flags = []  # one entry per batch: True when the batch is full-sized
        for batch in dl:
            d = batch['x'].reshape(-1).tolist()
            diff = max(d) - min(d)
            # Values inside one batch must stay close together.
            # NOTE(review): the 2*2*2 factor presumably covers the bucket width
            # (num_batch_per_bucket=2) and the split over two ranks -- confirm.
            assert diff < batch_size * 2 * 2 * 2
            data.extend(d)
            flags.append(len(d) == batch_size)

        # Expected number of samples on this rank.
        _num_samples = num_samples // num_device
        if drop_last and _num_samples % batch_size != 0:
            # The incomplete tail batch is dropped, so every batch is full.
            assert len(data) != _num_samples
            assert all(flags)
        elif _num_samples % batch_size != 0:
            # The tail batch is kept and must be smaller than batch_size.
            assert flags[-1] is False
        else:
            assert len(data) == _num_samples

        # NOTE(review): the final adjacent pair is not compared; without
        # shuffling the stream is expected to be strictly decreasing
        # (presumably longest-first bucketing -- confirm against the sampler).
        if not shuffle:
            for i in range(1, len(data) - 1):
                assert data[i] < data[i - 1]
        else:
            flags = []
            for i in range(1, len(data) - 1):
                flags.append(data[i] < data[i - 1])
            # A shuffled stream must not be fully descending.
            assert not all(flags)

        # datas is indexed per rank below; presumably one list per rank -- confirm.
        datas = fastnlp_paddle_all_gather(data)
        if drop_last:
            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
        else:
            assert len(set(datas[0] + datas[1])) == num_samples
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
@@ -9,6 +9,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
from fastNLP import prepare_paddle_dataloader, BucketedBatchSampler | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -738,3 +739,85 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@pytest.mark.paddle
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
@pytest.mark.parametrize("reproducible", ([True, False]))
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible):
    """Check that ``set_dist_repro_dataloader`` keeps the dataloader's settings.

    A dataloader is built with the parametrized ``shuffle`` / ``batch_size`` /
    ``drop_last`` values, passed through
    ``PaddleSingleDriver.set_dist_repro_dataloader`` on CPU, and iterated.
    The test verifies batch sizes, the total number of samples, and the
    ordering of the samples.

    :param shuffle: whether the original dataloader shuffles.
    :param batch_size: batch size of the original dataloader.
    :param drop_last: whether the original dataloader drops the last
        incomplete batch.
    :param reproducible: forwarded to ``set_dist_repro_dataloader``.
    """
    # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
    num_samples = 200
    dataset = PaddleNormalXYDataset(num_samples)
    dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
    model = PaddleNormalModel_Classification_1(1, 2)
    driver = PaddleSingleDriver(model, device="cpu")
    dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible)

    data = []
    flags = []  # one entry per batch: True when the batch is full-sized
    for batch in dl:
        flags.append(batch['x'].shape[0] == batch_size)
        data.extend(batch['x'].reshape(-1).tolist())

    if drop_last and num_samples % batch_size != 0:
        # The incomplete tail batch is dropped, so every batch is full.
        assert len(data) != num_samples
        assert all(flags)
    elif num_samples % batch_size != 0:
        # The tail batch is kept and must be smaller than batch_size.
        assert flags[-1] is False
    else:
        assert len(data) == num_samples

    if not shuffle:
        # Without shuffling the samples must arrive in strictly ascending order.
        for i in range(1, len(data)):
            assert data[i] > data[i - 1]
    else:
        flags = []
        for i in range(1, len(data)):
            flags.append(data[i] > data[i - 1])
        # A shuffled stream must not be fully ascending.
        assert not all(flags)
@pytest.mark.paddle
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
@pytest.mark.parametrize("reproducible", ([True, False]))
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible):
    """Check ``set_dist_repro_dataloader`` with a ``BucketedBatchSampler``.

    A dataloader driven by a ``BucketedBatchSampler`` is passed through
    ``PaddleSingleDriver.set_dist_repro_dataloader`` on CPU and iterated.
    The test verifies the value spread inside each batch, batch sizes, the
    total number of samples, and the ordering of the samples.

    :param shuffle: whether the batch sampler shuffles.
    :param batch_size: batch size of the batch sampler.
    :param drop_last: whether the batch sampler drops the last incomplete batch.
    :param reproducible: forwarded to ``set_dist_repro_dataloader``.
    """
    # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
    num_samples = 200
    dataset = PaddleNormalXYDataset(num_samples)
    sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
                                   shuffle=shuffle, num_batch_per_bucket=2)
    dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler)
    model = PaddleNormalModel_Classification_1(1, 2)
    driver = PaddleSingleDriver(model, device="cpu")
    dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible)

    data = []
    flags = []  # one entry per batch: True when the batch is full-sized
    for batch in dl:
        d = batch['x'].reshape(-1).tolist()
        diff = max(d) - min(d)
        # Values inside one batch must stay close together.
        # NOTE(review): the factor 2 presumably matches num_batch_per_bucket=2 -- confirm.
        assert diff < batch_size * 2
        data.extend(d)
        flags.append(len(d) == batch_size)

    if drop_last and num_samples % batch_size != 0:
        # The incomplete tail batch is dropped, so every batch is full.
        assert len(data) != num_samples
        assert all(flags)
    elif num_samples % batch_size != 0:
        # The tail batch is kept and must be smaller than batch_size.
        assert flags[-1] is False
    else:
        assert len(data) == num_samples

    if not shuffle:
        # Without shuffling the stream is expected to be strictly decreasing
        # (presumably longest-first bucketing -- confirm against the sampler).
        for i in range(1, len(data)):
            assert data[i] < data[i - 1]
    else:
        flags = []
        for i in range(1, len(data)):
            flags.append(data[i] < data[i - 1])
        # A shuffled stream must not be fully descending.
        assert not all(flags)
@@ -2,6 +2,7 @@ import pytest | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | ||||
from fastNLP import prepare_torch_dataloader | |||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
RandomSampler, | RandomSampler, | ||||
UnrepeatedSampler, | UnrepeatedSampler, | ||||
@@ -13,6 +14,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -834,3 +836,112 @@ class TestSaveLoad: | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
    """Ensure ``TorchDDPDriver.set_dist_repro_dataloader`` preserves the
    ``shuffle`` / ``batch_size`` / ``drop_last`` settings of the dataloader.

    The wrapped dataloader is iterated on devices ``[0, 1]``; batch sizes,
    the per-rank sample count, the sample ordering and the number of distinct
    samples across both ranks are asserted.
    """
    try:
        # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
        num_samples = 200
        dataset = TorchNormalXYDataset(num_samples)
        loader = prepare_torch_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
        model = TorchNormalModel_Classification_1(10, 32)
        devices = [torch.device(i) for i in [0, 1]]
        driver = TorchDDPDriver(model, parallel_device=devices)
        driver.setup()
        loader = driver.set_dist_repro_dataloader(dataloader=loader, dist='dist', reproducible=reproducible)

        seen = []
        is_full = []  # per batch: does it contain exactly batch_size rows?
        for batch in loader:
            x = batch['x']
            is_full.append(x.size(0) == batch_size)
            seen.extend(x.reshape(-1).tolist())

        # Expected sample count on this rank and the size of the ragged tail.
        per_rank = num_samples // 2
        remainder = per_rank % batch_size
        if remainder == 0:
            assert len(seen) == per_rank
        elif drop_last:
            assert len(seen) != per_rank
            assert all(is_full)
        else:
            assert is_full[-1] is False

        # Adjacent pairs, deliberately leaving out the pair ending at the last sample.
        adjacent = list(zip(seen[:-2], seen[1:-1]))
        if shuffle:
            assert not all(cur > prev for prev, cur in adjacent)
        else:
            for prev, cur in adjacent:
                assert cur > prev

        gathered = fastnlp_torch_all_gather(seen)
        distinct = len(set(gathered[0] + gathered[1]))
        expected = num_samples - remainder * 2 if drop_last else num_samples
        assert distinct == expected
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
@pytest.mark.torch
@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True):
    """Check ``set_dist_repro_dataloader`` with a ``BucketedBatchSampler``.

    A dataloader driven by a ``BucketedBatchSampler`` is passed through
    ``TorchDDPDriver.set_dist_repro_dataloader(dist='dist')`` on devices
    ``[0, 1]`` and iterated. The test verifies the value spread inside each
    batch, batch sizes, the number of samples seen on this rank, the ordering
    of the samples, and the number of distinct samples across both ranks.

    :param shuffle: whether the batch sampler shuffles.
    :param batch_size: batch size of the batch sampler.
    :param drop_last: whether the batch sampler drops the last incomplete batch.
    :param reproducible: forwarded to ``set_dist_repro_dataloader``.
    """
    try:
        # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
        num_samples = 200
        num_device = 2
        dataset = TorchNormalXYDataset(num_samples)
        sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last,
                                       shuffle=shuffle, num_batch_per_bucket=2)
        dl = prepare_torch_dataloader(dataset, batch_sampler=sampler)
        model = TorchNormalModel_Classification_1(10, 32)
        device = [torch.device(i) for i in [0, 1]]
        driver = TorchDDPDriver(model, parallel_device=device)
        driver.setup()
        dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible)

        data = []
        flags = []  # one entry per batch: True when the batch is full-sized
        for batch in dl:
            d = batch['x'].reshape(-1).tolist()
            diff = max(d) - min(d)
            # Values inside one batch must stay close together.
            # NOTE(review): the 2*2*2 factor presumably covers the bucket width
            # (num_batch_per_bucket=2) and the split over two ranks -- confirm.
            assert diff < batch_size * 2 * 2 * 2
            data.extend(d)
            flags.append(len(d) == batch_size)

        # Expected number of samples on this rank.
        _num_samples = num_samples // num_device
        if drop_last and _num_samples % batch_size != 0:
            # Compare against the per-rank count: comparing against the global
            # num_samples would be trivially true for a single rank.
            assert len(data) != _num_samples
            assert all(flags)
        elif _num_samples % batch_size != 0:
            # The tail batch is kept and must be smaller than batch_size.
            assert flags[-1] is False
        else:
            assert len(data) == _num_samples

        # NOTE(review): the final adjacent pair is not compared; without
        # shuffling the stream is expected to be strictly decreasing
        # (presumably longest-first bucketing -- confirm against the sampler).
        if not shuffle:
            for i in range(1, len(data) - 1):
                assert data[i] < data[i - 1]
        else:
            flags = []
            for i in range(1, len(data) - 1):
                flags.append(data[i] < data[i - 1])
            # A shuffled stream must not be fully descending.
            assert not all(flags)

        # datas is indexed per rank below; presumably one list per rank -- confirm.
        datas = fastnlp_torch_all_gather(data)
        if drop_last:
            assert len(set(datas[0] + datas[1])) == num_samples - _num_samples % batch_size * 2
        else:
            assert len(set(datas[0] + datas[1])) == num_samples
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
@@ -11,6 +11,7 @@ from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
from fastNLP import prepare_torch_dataloader, BucketedBatchSampler | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -710,3 +711,85 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | assert len(left_y_batches | already_seen_y_set) == len(dataset) | ||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@pytest.mark.torch
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
@pytest.mark.parametrize("reproducible", ([True, False]))
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible):
    """Ensure ``TorchSingleDriver.set_dist_repro_dataloader`` preserves the
    ``shuffle`` / ``batch_size`` / ``drop_last`` settings of the dataloader.

    The wrapped dataloader is iterated on CPU; batch sizes, the total sample
    count and the sample ordering are asserted.
    """
    # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
    num_samples = 100
    dataset = TorchNormalXYDataset(num_samples)
    loader = prepare_torch_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
    model = TorchNormalModel_Classification_1(10, 32)
    driver = TorchSingleDriver(model, device="cpu")
    loader = driver.set_dist_repro_dataloader(dataloader=loader, reproducible=reproducible)

    seen = []
    is_full = []  # per batch: does it contain exactly batch_size rows?
    for batch in loader:
        x = batch['x']
        is_full.append(x.size(0) == batch_size)
        seen.extend(x.reshape(-1).tolist())

    remainder = num_samples % batch_size
    if remainder == 0:
        assert len(seen) == num_samples
    elif drop_last:
        assert len(seen) != num_samples
        assert all(is_full)
    else:
        assert is_full[-1] is False

    adjacent = list(zip(seen, seen[1:]))
    if shuffle:
        # A shuffled stream must not be fully ascending.
        assert not all(cur > prev for prev, cur in adjacent)
    else:
        for prev, cur in adjacent:
            assert cur > prev
@pytest.mark.torch
@pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False]))
@pytest.mark.parametrize("reproducible", ([True, False]))
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible):
    """Ensure ``TorchSingleDriver.set_dist_repro_dataloader`` works with a
    dataloader driven by a ``BucketedBatchSampler``.

    The wrapped dataloader is iterated on CPU; the value spread inside each
    batch, batch sizes, the total sample count and the sample ordering are
    asserted.
    """
    # Verify that set_dist_repro_dataloader does not modify the dataloader parameters.
    num_samples = 100
    dataset = TorchNormalXYDataset(num_samples)
    sampler = BucketedBatchSampler(
        dataset,
        length=dataset._data,
        batch_size=batch_size,
        drop_last=drop_last,
        shuffle=shuffle,
        num_batch_per_bucket=2,
    )
    loader = prepare_torch_dataloader(dataset, batch_sampler=sampler)
    model = TorchNormalModel_Classification_1(10, 32)
    driver = TorchSingleDriver(model, device="cpu")
    loader = driver.set_dist_repro_dataloader(dataloader=loader, reproducible=reproducible)

    seen = []
    is_full = []  # per batch: does it contain exactly batch_size values?
    for batch in loader:
        values = batch['x'].reshape(-1).tolist()
        spread = max(values) - min(values)
        assert spread < batch_size * 2
        seen.extend(values)
        is_full.append(len(values) == batch_size)

    remainder = num_samples % batch_size
    if remainder == 0:
        assert len(seen) == num_samples
    elif drop_last:
        assert len(seen) != num_samples
        assert all(is_full)
    else:
        assert is_full[-1] is False

    adjacent = list(zip(seen, seen[1:]))
    if shuffle:
        # A shuffled stream must not be fully descending.
        assert not all(cur < prev for prev, cur in adjacent)
    else:
        for prev, cur in adjacent:
            assert cur < prev