You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 874 B

12345678910111213141516171819202122232425262728293031323334
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import sys
  4. import numpy as np
  5. import pytest
  6. from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset
  def test_abstract_cls():
      """Check that Dataset and StreamDataset refuse direct instantiation.

      Calling ``Dataset()`` and then ``StreamDataset()`` with no arguments
      is each expected to raise ``TypeError``.
      """
      # Order matters only for which failure is reported first.
      for dataset_cls in (Dataset, StreamDataset):
          with pytest.raises(TypeError):
              dataset_cls()
  def test_array_dataset():
      """Build an ArrayDataset from two random arrays and check item shapes.

      The first item is expected to hold one entry per source array, each
      with that array's per-sample shape, and ``len(dataset)`` is expected
      to equal the number of samples along the first axis.
      """
      num_samples = 10
      sample_shape = (3, 256, 256)
      target_shape = (1,)
      # Only shapes are asserted below, so the random values are irrelevant.
      samples = np.random.randint(0, 255, (num_samples,) + sample_shape)
      targets = np.random.randint(0, 9, (num_samples,) + target_shape)
      dataset = ArrayDataset(samples, targets)
      # Position 0 is the data entry, position 1 the label entry.
      for position, expected_shape in enumerate((sample_shape, target_shape)):
          assert dataset[0][position].shape == expected_shape
      assert len(dataset) == num_samples
  def test_array_dataset_dim_error():
      """Check that ArrayDataset raises ValueError for mismatched inputs.

      The data array has 10 entries along its first axis while the label
      array has only 1; constructing the dataset from them is expected to
      raise ``ValueError`` (presumably because those lengths differ).
      """
      ten_samples = np.random.randint(0, 255, (10, 3, 256, 256))
      lone_label = np.random.randint(0, 9, (1,))
      with pytest.raises(ValueError):
          ArrayDataset(ten_samples, lone_label)