|
|
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
from torch.utils.data import default_collate, SequentialSampler, RandomSampler |
|
|
|
from torch.utils.data import SequentialSampler, RandomSampler |
|
|
|
|
|
|
|
d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) |
|
|
|
|
|
|
@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] |
|
|
|
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) |
|
|
|
|
|
|
|
|
|
|
|
def test_pad_val(tensor, val=0): |
|
|
|
def _test_pad_val(tensor, val=0): |
|
|
|
if isinstance(tensor, torch.Tensor): |
|
|
|
tensor = tensor.tolist() |
|
|
|
for item in tensor: |
|
|
@@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0): |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
class TestMixDataLoader: |
|
|
|
|
|
|
|
def test_sequential_init(self): |
|
|
@@ -44,7 +45,7 @@ class TestMixDataLoader: |
|
|
|
if idx > 1: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# collate_fn = Callable |
|
|
|
def collate_batch(batch): |
|
|
@@ -73,13 +74,13 @@ class TestMixDataLoader: |
|
|
|
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) |
|
|
|
for idx, batch in enumerate(dl2): |
|
|
|
if idx == 0: |
|
|
|
assert test_pad_val(batch['x'], val=-1) |
|
|
|
assert _test_pad_val(batch['x'], val=-1) |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
if idx == 1: |
|
|
|
assert test_pad_val(batch['x'], val=-2) |
|
|
|
assert _test_pad_val(batch['x'], val=-2) |
|
|
|
assert batch['x'].shape == torch.Size([16, 3]) |
|
|
|
if idx > 1: |
|
|
|
assert test_pad_val(batch['x'], val=-3) |
|
|
|
assert _test_pad_val(batch['x'], val=-3) |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
|
|
|
|
# sampler 为 str |
|
|
@@ -100,7 +101,7 @@ class TestMixDataLoader: |
|
|
|
if idx > 1: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
for idx, batch in enumerate(dl4): |
|
|
|
if idx == 0: |
|
|
@@ -117,7 +118,7 @@ class TestMixDataLoader: |
|
|
|
if idx > 1: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# sampler 为 Dict |
|
|
|
samplers = {'d1': SequentialSampler(d1), |
|
|
@@ -136,7 +137,7 @@ class TestMixDataLoader: |
|
|
|
if idx > 1: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# ds_ratio 为 'truncate_to_least' |
|
|
|
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) |
|
|
@@ -153,7 +154,7 @@ class TestMixDataLoader: |
|
|
|
# d3 |
|
|
|
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx > 2: |
|
|
|
raise ValueError(f"ds_ratio: 'truncate_to_least' error") |
|
|
|
|
|
|
@@ -169,7 +170,7 @@ class TestMixDataLoader: |
|
|
|
if 36 <= idx < 54: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 54: |
|
|
|
raise ValueError(f"ds_ratio: 'pad_to_most' error") |
|
|
|
|
|
|
@@ -186,7 +187,7 @@ class TestMixDataLoader: |
|
|
|
if 4 <= idx < 41: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 41: |
|
|
|
raise ValueError(f"ds_ratio: 'pad_to_most' error") |
|
|
|
|
|
|
@@ -200,7 +201,7 @@ class TestMixDataLoader: |
|
|
|
# d3 |
|
|
|
assert batch['x'].shape == torch.Size([16, 4]) |
|
|
|
|
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 19: |
|
|
|
raise ValueError(f"ds_ratio: 'pad_to_most' error") |
|
|
|
|
|
|
@@ -208,7 +209,7 @@ class TestMixDataLoader: |
|
|
|
datasets = {'d1': d1, 'd2': d2, 'd3': d3} |
|
|
|
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) |
|
|
|
for idx, batch in enumerate(dl): |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 22: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
|
|
|
@@ -223,7 +224,7 @@ class TestMixDataLoader: |
|
|
|
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) |
|
|
|
for idx, batch in enumerate(dl1): |
|
|
|
assert isinstance(batch['x'], list) |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 22: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
|
|
|
@@ -236,12 +237,12 @@ class TestMixDataLoader: |
|
|
|
# sampler 为 str |
|
|
|
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) |
|
|
|
for idx, batch in enumerate(dl3): |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 22: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) |
|
|
|
for idx, batch in enumerate(dl4): |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 22: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
# sampler 为 Dict |
|
|
@@ -250,7 +251,7 @@ class TestMixDataLoader: |
|
|
|
'd3': RandomSampler(d3)} |
|
|
|
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) |
|
|
|
for idx, batch in enumerate(dl5): |
|
|
|
assert test_pad_val(batch['x'], val=0) |
|
|
|
assert _test_pad_val(batch['x'], val=0) |
|
|
|
if idx >= 22: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
# ds_ratio 为 'truncate_to_least' |
|
|
@@ -332,7 +333,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 20: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# collate_fn = Callable |
|
|
|
def collate_batch(batch): |
|
|
@@ -360,16 +361,16 @@ class TestMixDataLoader: |
|
|
|
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) |
|
|
|
for idx, batch in enumerate(dl1): |
|
|
|
if idx == 0 or idx == 3: |
|
|
|
assert test_pad_val(batch['x'], val=-1) |
|
|
|
assert _test_pad_val(batch['x'], val=-1) |
|
|
|
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
elif idx == 1 or idx == 4: |
|
|
|
# d2 |
|
|
|
assert test_pad_val(batch['x'], val=-2) |
|
|
|
assert _test_pad_val(batch['x'], val=-2) |
|
|
|
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] |
|
|
|
assert batch['x'].shape[1] == 3 |
|
|
|
elif idx == 2 or 4 < idx <= 20: |
|
|
|
assert test_pad_val(batch['x'], val=-3) |
|
|
|
assert _test_pad_val(batch['x'], val=-3) |
|
|
|
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 20: |
|
|
@@ -391,7 +392,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 20: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
for idx, batch in enumerate(dl3): |
|
|
|
if idx == 0 or idx == 3: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
@@ -402,7 +403,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 20: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
# sampler 为 Dict |
|
|
|
samplers = {'d1': SequentialSampler(d1), |
|
|
|
'd2': SequentialSampler(d2), |
|
|
@@ -420,7 +421,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 20: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# ds_ratio 为 'truncate_to_least' |
|
|
|
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) |
|
|
@@ -437,7 +438,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 5: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# ds_ratio 为 'pad_to_most' |
|
|
|
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) |
|
|
@@ -456,7 +457,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx >= 51: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
# ds_ratio 为 Dict[str, float] |
|
|
|
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} |
|
|
@@ -474,7 +475,7 @@ class TestMixDataLoader: |
|
|
|
assert batch['x'].shape[1] == 4 |
|
|
|
if idx > 39: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |
|
|
|
|
|
|
|
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} |
|
|
|
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) |
|
|
@@ -492,4 +493,4 @@ class TestMixDataLoader: |
|
|
|
|
|
|
|
if idx > 18: |
|
|
|
raise ValueError(f"out of range") |
|
|
|
test_pad_val(batch['x'], val=0) |
|
|
|
_test_pad_val(batch['x'], val=0) |