From 7d83a9ad47a4aede61bd3e4965e38e57909e3aac Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Nov 2022 11:51:26 +0800 Subject: [PATCH] fix(imperative): infinite sampler support get batchsize GitOrigin-RevId: 52e3d6524932e74432ad5c989cab3f7fddd9192f --- imperative/python/megengine/data/sampler.py | 7 +++++++ imperative/python/test/unit/data/test_sampler.py | 14 +++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index 7105f16f..1c8af495 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -326,3 +326,10 @@ class Infinite(MapSampler): def __len__(self): return np.iinfo(np.int64).max + + def __getattr__(self, name): + # if attribute could not be found in Infinite, + # try to find it in self.sampler + if name not in self.__dict__: + return getattr(self.sampler, name) + return self.__dict__[name] diff --git a/imperative/python/test/unit/data/test_sampler.py b/imperative/python/test/unit/data/test_sampler.py index 5204d2d7..f4b33a90 100644 --- a/imperative/python/test/unit/data/test_sampler.py +++ b/imperative/python/test/unit/data/test_sampler.py @@ -7,7 +7,12 @@ import numpy as np import pytest from megengine.data.dataset import ArrayDataset -from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler +from megengine.data.sampler import ( + Infinite, + RandomSampler, + ReplacementSampler, + SequentialSampler, +) def test_sequential_sampler(): @@ -25,6 +30,13 @@ def test_RandomSampler(): assert indices == sorted(list(each[0] for each in sample_indices)) +def test_InfiniteSampler(): + indices = list(range(20)) + seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2) + inf_sampler = Infinite(seque_sampler) + assert inf_sampler.batch_size == seque_sampler.batch_size + + def test_random_sampler_seed(): seed = [0, 1] indices = list(range(20))