|
|
@@ -7,7 +7,12 @@ import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
from megengine.data.dataset import ArrayDataset |
|
|
|
from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler |
|
|
|
from megengine.data.sampler import ( |
|
|
|
Infinite, |
|
|
|
RandomSampler, |
|
|
|
ReplacementSampler, |
|
|
|
SequentialSampler, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_sequential_sampler(): |
|
|
@@ -25,6 +30,13 @@ def test_RandomSampler(): |
|
|
|
assert indices == sorted(list(each[0] for each in sample_indices)) |
|
|
|
|
|
|
|
|
|
|
|
def test_InfiniteSampler(): |
|
|
|
indices = list(range(20)) |
|
|
|
seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2) |
|
|
|
inf_sampler = Infinite(seque_sampler) |
|
|
|
assert inf_sampler.batch_size == seque_sampler.batch_size |
|
|
|
|
|
|
|
|
|
|
|
def test_random_sampler_seed(): |
|
|
|
seed = [0, 1] |
|
|
|
indices = list(range(20)) |
|
|
|