You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sampler.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -*- coding: utf-8 -*-
  2. import copy
  3. import os
  4. import sys
  5. import numpy as np
  6. import pytest
  7. from megengine.data.dataset import ArrayDataset
  8. from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler
  9. def test_sequential_sampler():
  10. indices = list(range(100))
  11. sampler = SequentialSampler(ArrayDataset(indices))
  12. assert indices == list(each[0] for each in sampler)
  13. def test_RandomSampler():
  14. indices = list(range(20))
  15. indices_copy = copy.deepcopy(indices)
  16. sampler = RandomSampler(ArrayDataset(indices_copy))
  17. sample_indices = sampler
  18. assert indices != list(each[0] for each in sample_indices)
  19. assert indices == sorted(list(each[0] for each in sample_indices))
  20. def test_random_sampler_seed():
  21. seed = [0, 1]
  22. indices = list(range(20))
  23. indices_copy1 = copy.deepcopy(indices)
  24. indices_copy2 = copy.deepcopy(indices)
  25. indices_copy3 = copy.deepcopy(indices)
  26. sampler1 = RandomSampler(ArrayDataset(indices_copy1), seed=seed[0])
  27. sampler2 = RandomSampler(ArrayDataset(indices_copy2), seed=seed[0])
  28. sampler3 = RandomSampler(ArrayDataset(indices_copy3), seed=seed[1])
  29. assert indices != list(each[0] for each in sampler1)
  30. assert indices != list(each[0] for each in sampler2)
  31. assert indices != list(each[0] for each in sampler3)
  32. assert indices == sorted(list(each[0] for each in sampler1))
  33. assert indices == sorted(list(each[0] for each in sampler2))
  34. assert indices == sorted(list(each[0] for each in sampler3))
  35. assert list(each[0] for each in sampler1) == list(each[0] for each in sampler2)
  36. assert list(each[0] for each in sampler1) != list(each[0] for each in sampler3)
  37. def test_ReplacementSampler():
  38. num_samples = 30
  39. indices = list(range(20))
  40. weights = list(range(20))
  41. sampler = ReplacementSampler(
  42. ArrayDataset(indices), num_samples=num_samples, weights=weights
  43. )
  44. assert len(list(each[0] for each in sampler)) == num_samples
  45. def test_sampler_drop_last_false():
  46. batch_size = 5
  47. drop_last = False
  48. indices = list(range(24))
  49. sampler = SequentialSampler(
  50. ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
  51. )
  52. assert len([each for each in sampler]) == len(sampler)
  53. def test_sampler_drop_last_true():
  54. batch_size = 5
  55. drop_last = True
  56. indices = list(range(24))
  57. sampler = SequentialSampler(
  58. ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
  59. )
  60. assert len([each for each in sampler]) == len(sampler)