|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- # -*- coding: utf-8 -*-
- import copy
- import os
- import sys
-
- import numpy as np
- import pytest
-
- from megengine.data.dataset import ArrayDataset
- from megengine.data.sampler import (
- Infinite,
- RandomSampler,
- ReplacementSampler,
- SequentialSampler,
- )
-
-
def test_sequential_sampler():
    """SequentialSampler with the default batch size yields indices in order."""
    expected = list(range(100))
    seq = SequentialSampler(ArrayDataset(expected))
    # Each yielded item is a batch; take its first (only) index.
    yielded = [batch[0] for batch in seq]
    assert expected == yielded
-
-
def test_RandomSampler():
    """RandomSampler yields a shuffled permutation of the dataset indices.

    The sampler is iterated exactly once, so both assertions check the same
    epoch's ordering. Iterating again may produce a different shuffle.
    """
    indices = list(range(20))
    # Hand the dataset its own copy so the reference list stays untouched.
    indices_copy = copy.deepcopy(indices)
    sampler = RandomSampler(ArrayDataset(indices_copy))
    sampled = [each[0] for each in sampler]
    # The order is shuffled. No seed is given; assuming a uniform shuffle,
    # the identity permutation would occur with probability 1/20!.
    assert indices != sampled
    # ...yet every index appears exactly once.
    assert indices == sorted(sampled)
-
-
def test_InfiniteSampler():
    """Infinite must expose the batch size of the sampler it wraps."""
    base = SequentialSampler(ArrayDataset(list(range(20))), batch_size=2)
    wrapped = Infinite(base)
    assert wrapped.batch_size == base.batch_size
-
-
def test_random_sampler_seed():
    """Equal seeds give equal shuffles; different seeds give different ones.

    Each sampler is iterated exactly once and all comparisons use that single
    materialized epoch. Iterating a seeded sampler repeatedly may advance its
    RNG, so comparing samplers that were iterated a different number of times
    would compare different epochs rather than different seeds.
    """
    seed = [0, 1]
    indices = list(range(20))
    # Give each dataset its own copy so the reference list stays untouched.
    sampler1 = RandomSampler(ArrayDataset(copy.deepcopy(indices)), seed=seed[0])
    sampler2 = RandomSampler(ArrayDataset(copy.deepcopy(indices)), seed=seed[0])
    sampler3 = RandomSampler(ArrayDataset(copy.deepcopy(indices)), seed=seed[1])

    order1, order2, order3 = (
        [each[0] for each in sampler] for sampler in (sampler1, sampler2, sampler3)
    )

    for order in (order1, order2, order3):
        # Shuffled, but still a permutation of the dataset indices.
        assert indices != order
        assert indices == sorted(order)

    # Same seed -> same order; different seed -> different order.
    assert order1 == order2
    assert order1 != order3
-
-
def test_ReplacementSampler():
    """ReplacementSampler draws num_samples indices from the dataset, with repeats."""
    num_samples = 30
    indices = list(range(20))
    weights = list(range(20))
    sampler = ReplacementSampler(
        ArrayDataset(indices), num_samples=num_samples, weights=weights
    )
    sampled = [each[0] for each in sampler]
    assert len(sampled) == num_samples
    # Every draw must be a valid dataset index.
    assert set(sampled) <= set(indices)
    # 30 draws from only 20 distinct indices must contain repeats.
    assert len(set(sampled)) < num_samples
-
-
def test_sampler_drop_last_false():
    """With drop_last=False the trailing partial batch is kept.

    24 indices in batches of 5 give four full batches plus one batch of 4.
    """
    batch_size = 5
    drop_last = False
    indices = list(range(24))
    sampler = SequentialSampler(
        ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
    )
    batches = list(sampler)
    assert len(batches) == len(sampler)
    # Pin the expected values: comparing against len(sampler) alone would
    # pass even if both the iteration and __len__ were wrong.
    assert len(batches) == 5
    assert [len(batch) for batch in batches] == [5, 5, 5, 5, 4]
-
-
def test_sampler_drop_last_true():
    """With drop_last=True the trailing partial batch is discarded.

    24 indices in batches of 5 leave four full batches; the remaining 4
    indices are dropped.
    """
    batch_size = 5
    drop_last = True
    indices = list(range(24))
    sampler = SequentialSampler(
        ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
    )
    batches = list(sampler)
    assert len(batches) == len(sampler)
    # Pin the expected values: comparing against len(sampler) alone would
    # pass even if both the iteration and __len__ were wrong.
    assert len(batches) == 4
    assert all(len(batch) == batch_size for batch in batches)
|