You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_transform.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. from megengine.data.transform import *
  11. data_shape = (100, 100, 3)
  12. label_shape = (4,)
  13. ToMode_target_shape = (3, 100, 100)
  14. CenterCrop_size = (90, 70)
  15. CenterCrop_target_shape = CenterCrop_size + (3,)
  16. RandomResizedCrop_size = (50, 50)
  17. RandomResizedCrop_target_shape = RandomResizedCrop_size + (3,)
  18. def generate_data():
  19. return [
  20. (
  21. (np.random.rand(*data_shape) * 255).astype(np.uint8),
  22. np.random.randint(10, size=label_shape),
  23. )
  24. for _ in range(*label_shape)
  25. ]
  26. def test_ToMode():
  27. t = ToMode(mode="CHW")
  28. aug_data = t.apply_batch(generate_data())
  29. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  30. target_shape = [(ToMode_target_shape, label_shape)] * 4
  31. assert aug_data_shape == target_shape
  32. def test_CenterCrop():
  33. t = CenterCrop(output_size=CenterCrop_size)
  34. aug_data = t.apply_batch(generate_data())
  35. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  36. target_shape = [(CenterCrop_target_shape, label_shape)] * 4
  37. assert aug_data_shape == target_shape
  38. def test_ColorJitter():
  39. t = ColorJitter()
  40. aug_data = t.apply_batch(generate_data())
  41. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  42. target_shape = [(data_shape, label_shape)] * 4
  43. assert aug_data_shape == target_shape
  44. def test_RandomHorizontalFlip():
  45. t = RandomHorizontalFlip(prob=1)
  46. aug_data = t.apply_batch(generate_data())
  47. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  48. target_shape = [(data_shape, label_shape)] * 4
  49. assert aug_data_shape == target_shape
  50. def test_RandomVerticalFlip():
  51. t = RandomVerticalFlip(prob=1)
  52. aug_data = t.apply_batch(generate_data())
  53. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  54. target_shape = [(data_shape, label_shape)] * 4
  55. assert aug_data_shape == target_shape
  56. def test_RandomResizedCrop():
  57. t = RandomResizedCrop(output_size=RandomResizedCrop_size)
  58. aug_data = t.apply_batch(generate_data())
  59. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  60. target_shape = [(RandomResizedCrop_target_shape, label_shape)] * 4
  61. assert aug_data_shape == target_shape
  62. def test_Normalize():
  63. t = Normalize()
  64. aug_data = t.apply_batch(generate_data())
  65. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  66. target_shape = [(data_shape, label_shape)] * 4
  67. assert aug_data_shape == target_shape
  68. def test_RandomCrop():
  69. t = RandomCrop((150, 120), padding_size=10, padding_value=[1, 2, 3])
  70. aug_data = t.apply_batch(generate_data())
  71. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  72. target_shape = [((150, 120, 3), label_shape)] * 4
  73. assert aug_data_shape == target_shape
  74. def test_Compose():
  75. t = Compose(
  76. [
  77. CenterCrop(output_size=CenterCrop_size),
  78. RandomHorizontalFlip(prob=1),
  79. ToMode(mode="CHW"),
  80. ]
  81. )
  82. aug_data = t.apply_batch(generate_data())
  83. aug_data_shape = [(a.shape, b.shape) for a, b in aug_data]
  84. target_shape = [((3, 90, 70), label_shape)] * 4
  85. assert aug_data_shape == target_shape, "aug {}, target {}".format(
  86. aug_data_shape, target_shape
  87. )

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。