You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_transform.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. from megengine.data.transform import *
  4. data_shape = (100, 100, 3)
  5. label_shape = (4,)
  6. ToMode_target_shape = (3, 100, 100)
  7. CenterCrop_size = (90, 70)
  8. CenterCrop_target_shape = CenterCrop_size + (3,)
  9. RandomResizedCrop_size = (50, 50)
  10. RandomResizedCrop_target_shape = RandomResizedCrop_size + (3,)
  11. def generate_data():
  12. return [
  13. (
  14. (np.random.rand(*data_shape) * 255).astype(np.uint8),
  15. np.random.randint(10, size=label_shape),
  16. )
  17. for _ in range(*label_shape)
  18. ]
  19. def test_ToMode():
  20. t = ToMode(mode="CHW")
  21. aug_data = t.apply_batch(generate_data())
  22. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  23. target_shape = [(ToMode_target_shape, label_shape)] * 4
  24. assert aug_data_shape == target_shape
  25. def test_CenterCrop():
  26. t = CenterCrop(output_size=CenterCrop_size)
  27. aug_data = t.apply_batch(generate_data())
  28. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  29. target_shape = [(CenterCrop_target_shape, label_shape)] * 4
  30. assert aug_data_shape == target_shape
  31. def test_ColorJitter():
  32. t = ColorJitter()
  33. aug_data = t.apply_batch(generate_data())
  34. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  35. target_shape = [(data_shape, label_shape)] * 4
  36. assert aug_data_shape == target_shape
  37. def test_RandomHorizontalFlip():
  38. t = RandomHorizontalFlip(prob=1)
  39. aug_data = t.apply_batch(generate_data())
  40. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  41. target_shape = [(data_shape, label_shape)] * 4
  42. assert aug_data_shape == target_shape
  43. def test_RandomVerticalFlip():
  44. t = RandomVerticalFlip(prob=1)
  45. aug_data = t.apply_batch(generate_data())
  46. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  47. target_shape = [(data_shape, label_shape)] * 4
  48. assert aug_data_shape == target_shape
  49. def test_RandomResizedCrop():
  50. t = RandomResizedCrop(output_size=RandomResizedCrop_size)
  51. aug_data = t.apply_batch(generate_data())
  52. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  53. target_shape = [(RandomResizedCrop_target_shape, label_shape)] * 4
  54. assert aug_data_shape == target_shape
  55. def test_Normalize():
  56. t = Normalize()
  57. aug_data = t.apply_batch(generate_data())
  58. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  59. target_shape = [(data_shape, label_shape)] * 4
  60. assert aug_data_shape == target_shape
  61. def test_RandomCrop():
  62. t = RandomCrop((150, 120), padding_size=10, padding_value=[1, 2, 3])
  63. aug_data = t.apply_batch(generate_data())
  64. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  65. target_shape = [((150, 120, 3), label_shape)] * 4
  66. assert aug_data_shape == target_shape
  67. def test_Compose():
  68. t = Compose(
  69. [
  70. CenterCrop(output_size=CenterCrop_size),
  71. RandomHorizontalFlip(prob=1),
  72. ToMode(mode="CHW"),
  73. ]
  74. )
  75. aug_data = t.apply_batch(generate_data())
  76. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  77. target_shape = [((3, 90, 70), label_shape)] * 4
  78. assert aug_data_shape == target_shape, "aug {}, target {}".format(
  79. aug_data_shape, target_shape
  80. )