From 756bd09d1adb6e87ba8c44973b9c49a7df4a3fbb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 23 May 2022 13:17:45 +0000 Subject: [PATCH] =?UTF-8?q?1.=E5=88=A0=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E6=A0=87=E7=AD=BE=202.=E4=B8=BAtest?= =?UTF-8?q?=5Fmixdataloader=E6=B7=BB=E5=8A=A0torch=E6=A0=87=E7=AD=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataloaders/torch_dataloader/test_fdl.py | 1 - tests/core/dataloaders/torch_dataloader/test_mixdataloader.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index d977e52f..b53790bb 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -147,7 +147,6 @@ class TestFdl: assert 'Parameter:prefetch_factor' in out[0] @recover_logger - @pytest.mark.temp def test_version_111(self): if parse_version(torch.__version__) <= parse_version('1.7'): pytest.skip("Torch version smaller than 1.7") diff --git a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py index 35872b39..17c151ea 100644 --- a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py +++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py @@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch - from torch.utils.data import default_collate, SequentialSampler, RandomSampler + from torch.utils.data import SequentialSampler, RandomSampler d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) @@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0): return True +@pytest.mark.torch class TestMixDataLoader: def test_sequential_init(self):