You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sampler.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # -*- coding: utf-8 -*-
  2. import copy
  3. import os
  4. import sys
  5. import numpy as np
  6. import pytest
  7. from megengine.data.dataset import ArrayDataset
  8. from megengine.data.sampler import (
  9. Infinite,
  10. RandomSampler,
  11. ReplacementSampler,
  12. SequentialSampler,
  13. )
  14. def test_sequential_sampler():
  15. indices = list(range(100))
  16. sampler = SequentialSampler(ArrayDataset(indices))
  17. assert indices == list(each[0] for each in sampler)
  18. def test_RandomSampler():
  19. indices = list(range(20))
  20. indices_copy = copy.deepcopy(indices)
  21. sampler = RandomSampler(ArrayDataset(indices_copy))
  22. sample_indices = sampler
  23. assert indices != list(each[0] for each in sample_indices)
  24. assert indices == sorted(list(each[0] for each in sample_indices))
  25. def test_InfiniteSampler():
  26. indices = list(range(20))
  27. seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2)
  28. inf_sampler = Infinite(seque_sampler)
  29. assert inf_sampler.batch_size == seque_sampler.batch_size
  30. def test_random_sampler_seed():
  31. seed = [0, 1]
  32. indices = list(range(20))
  33. indices_copy1 = copy.deepcopy(indices)
  34. indices_copy2 = copy.deepcopy(indices)
  35. indices_copy3 = copy.deepcopy(indices)
  36. sampler1 = RandomSampler(ArrayDataset(indices_copy1), seed=seed[0])
  37. sampler2 = RandomSampler(ArrayDataset(indices_copy2), seed=seed[0])
  38. sampler3 = RandomSampler(ArrayDataset(indices_copy3), seed=seed[1])
  39. assert indices != list(each[0] for each in sampler1)
  40. assert indices != list(each[0] for each in sampler2)
  41. assert indices != list(each[0] for each in sampler3)
  42. assert indices == sorted(list(each[0] for each in sampler1))
  43. assert indices == sorted(list(each[0] for each in sampler2))
  44. assert indices == sorted(list(each[0] for each in sampler3))
  45. assert list(each[0] for each in sampler1) == list(each[0] for each in sampler2)
  46. assert list(each[0] for each in sampler1) != list(each[0] for each in sampler3)
  47. def test_ReplacementSampler():
  48. num_samples = 30
  49. indices = list(range(20))
  50. weights = list(range(20))
  51. sampler = ReplacementSampler(
  52. ArrayDataset(indices), num_samples=num_samples, weights=weights
  53. )
  54. assert len(list(each[0] for each in sampler)) == num_samples
  55. def test_sampler_drop_last_false():
  56. batch_size = 5
  57. drop_last = False
  58. indices = list(range(24))
  59. sampler = SequentialSampler(
  60. ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
  61. )
  62. assert len([each for each in sampler]) == len(sampler)
  63. def test_sampler_drop_last_true():
  64. batch_size = 5
  65. drop_last = True
  66. indices = list(range(24))
  67. sampler = SequentialSampler(
  68. ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last
  69. )
  70. assert len([each for each in sampler]) == len(sampler)