diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index 7105f16f..1c8af495 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -326,3 +326,10 @@ class Infinite(MapSampler): def __len__(self): return np.iinfo(np.int64).max + + def __getattr__(self, name): + # if attribute could not be found in Infinite, + # try to find it in self.sampler + if name not in self.__dict__: + return getattr(self.sampler, name) + return self.__dict__[name] diff --git a/imperative/python/test/unit/data/test_sampler.py b/imperative/python/test/unit/data/test_sampler.py index 5204d2d7..f4b33a90 100644 --- a/imperative/python/test/unit/data/test_sampler.py +++ b/imperative/python/test/unit/data/test_sampler.py @@ -7,7 +7,12 @@ import numpy as np import pytest from megengine.data.dataset import ArrayDataset -from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler +from megengine.data.sampler import ( + Infinite, + RandomSampler, + ReplacementSampler, + SequentialSampler, +) def test_sequential_sampler(): @@ -25,6 +30,13 @@ def test_RandomSampler(): assert indices == sorted(list(each[0] for each in sample_indices)) +def test_InfiniteSampler(): + indices = list(range(20)) + seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2) + inf_sampler = Infinite(seque_sampler) + assert inf_sampler.batch_size == seque_sampler.batch_size + + def test_random_sampler_seed(): seed = [0, 1] indices = list(range(20))