Browse Source

fix(pytest/windows/impertive): fix impertive pytest failed on windows

GitOrigin-RevId: 02f4c0a0be
release-1.1
Megvii Engine Team 4 years ago
parent
commit
786399c5c9
2 changed files with 15 additions and 0 deletions
  1. +6
    -0
      imperative/python/megengine/data/dataloader.py
  2. +9
    -0
      imperative/python/test/unit/data/test_dataloader.py

+ 6
- 0
imperative/python/megengine/data/dataloader.py View File

@@ -9,6 +9,7 @@
import collections import collections
import math import math
import multiprocessing import multiprocessing
import platform
import queue import queue
import random import random
import time import time
@@ -113,6 +114,11 @@ class DataLoader:
self.__initialized = True self.__initialized = True


def __iter__(self): def __iter__(self):
if platform.system() == "Windows":
print(
"pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
)
self.num_workers = 0
if self.num_workers == 0: if self.num_workers == 0:
return _SerialDataLoaderIter(self) return _SerialDataLoaderIter(self)
else: else:


+ 9
- 0
imperative/python/test/unit/data/test_dataloader.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os import os
import platform
import time import time


import numpy as np import numpy as np
@@ -89,6 +90,10 @@ def test_dataloader_parallel():
assert label.shape == (4,) assert label.shape == (4,)




@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_timeout(): def test_dataloader_parallel_timeout():
dataset = init_dataset() dataset = init_dataset()


@@ -112,6 +117,10 @@ def test_dataloader_parallel_timeout():
batch_data = next(data_iter) batch_data = next(data_iter)




@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_worker_exception(): def test_dataloader_parallel_worker_exception():
dataset = init_dataset() dataset = init_dataset()




Loading…
Cancel
Save