Browse Source

fix(imperative): infinite sampler support get batchsize

GitOrigin-RevId: 52e3d65249
master
Megvii Engine Team 2 years ago
parent
commit
7d83a9ad47
2 changed files with 20 additions and 1 deletions
  1. +7
    -0
      imperative/python/megengine/data/sampler.py
  2. +13
    -1
      imperative/python/test/unit/data/test_sampler.py

+ 7
- 0
imperative/python/megengine/data/sampler.py View File

@@ -326,3 +326,10 @@ class Infinite(MapSampler):

def __len__(self):
return np.iinfo(np.int64).max

def __getattr__(self, name):
# if attribute could not be found in Infinite,
# try to find it in self.sampler
if name not in self.__dict__:
return getattr(self.sampler, name)
return self.__dict__[name]

+ 13
- 1
imperative/python/test/unit/data/test_sampler.py View File

@@ -7,7 +7,12 @@ import numpy as np
import pytest

from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler
from megengine.data.sampler import (
Infinite,
RandomSampler,
ReplacementSampler,
SequentialSampler,
)


def test_sequential_sampler():
@@ -25,6 +30,13 @@ def test_RandomSampler():
assert indices == sorted(list(each[0] for each in sample_indices))


def test_InfiniteSampler():
indices = list(range(20))
seque_sampler = SequentialSampler(ArrayDataset(indices), batch_size=2)
inf_sampler = Infinite(seque_sampler)
assert inf_sampler.batch_size == seque_sampler.batch_size


def test_random_sampler_seed():
seed = [0, 1]
indices = list(range(20))


Loading…
Cancel
Save