diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index d977e52f..b53790bb 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -147,7 +147,6 @@ class TestFdl: assert 'Parameter:prefetch_factor' in out[0] @recover_logger - @pytest.mark.temp def test_version_111(self): if parse_version(torch.__version__) <= parse_version('1.7'): pytest.skip("Torch version smaller than 1.7") diff --git a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py index 35872b39..17c151ea 100644 --- a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py +++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py @@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch - from torch.utils.data import default_collate, SequentialSampler, RandomSampler + from torch.utils.data import SequentialSampler, RandomSampler d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) @@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0): return True +@pytest.mark.torch class TestMixDataLoader: def test_sequential_init(self):