
dataloader.py
import os

import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torch.utils.data
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
class HybridTrainPipe(Pipeline):
    """DALI pipeline for the ImageNet training split: decode, random resized crop, color jitter, normalize."""

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1,
                 spos_pre=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
                                    random_shuffle=True)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.BGR)  # color_space_type
        self.res = ops.RandomResizedCrop(device="gpu", size=crop,
                                         interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        self.twist = ops.ColorTwist(device="gpu")
        self.jitter_rng = ops.Uniform(range=[0.6, 1.4])
        # An earlier version switched on spos_pre (mean=0., std=1. for SPOS preprocessing) and
        # passed image_type=color_space_type; that argument was removed from CropMirrorNormalize,
        # and the color space is now selected in ImageDecoder instead.
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        images = self.twist(images, saturation=self.jitter_rng(),
                            contrast=self.jitter_rng(), brightness=self.jitter_rng())
        # (This call was temporarily removed at one point to check whether the preprocessing
        # was responsible for zero accuracy.)
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]
class HybridValPipe(Pipeline):
    """DALI pipeline for the ImageNet validation split: decode, resize shorter side, center crop, normalize."""

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1,
                 spos_pre=False, shuffle=False):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
                                    random_shuffle=shuffle)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.BGR)
        self.res = ops.Resize(device="gpu", resize_shorter=size,
                              interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        # As in the training pipeline, the SPOS variant (mean=0., std=1.) and the removed
        # image_type=color_space_type argument are no longer used here.
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]
class ClassificationWrapper:
    """Adapts a DALIClassificationIterator to the (images, labels) tuples expected by the training loop."""

    def __init__(self, loader, size):
        self.loader = loader
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        data = next(self.loader)
        return data[0]["data"], data[0]["label"].view(-1).long().cuda(device="cuda:0", non_blocking=True)

    def __len__(self):
        return self.size
def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256,
                           spos_preprocessing=False, seed=12, shuffle=False, device_id=None):
    world_size, local_rank = 1, 0
    if device_id is None:
        device_id = torch.cuda.device_count() - 1  # default to the last GPU
    if split == "train":
        pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                   data_dir=os.path.join(image_dir, "train"), seed=seed,
                                   crop=crop, world_size=world_size, local_rank=local_rank,
                                   spos_pre=spos_preprocessing)
    elif split == "val":
        pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                 data_dir=os.path.join(image_dir, "val"), seed=seed,
                                 crop=crop, size=val_size, world_size=world_size, local_rank=local_rank,
                                 spos_pre=spos_preprocessing, shuffle=shuffle)
    else:
        raise AssertionError(f"unknown split {split!r}, expected 'train' or 'val'")
    pipeline.build()
    num_samples = pipeline.epoch_size("Reader")
    # last_batch_policy replaces the deprecated fill_last_batch argument; train and val use the
    # same policy. See:
    # https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_plugin_api.html?highlight=daliclassificationiterator#nvidia.dali.plugin.pytorch.DALIClassificationIterator
    # PARTIAL returns the last, possibly smaller, batch as-is, which matches the ceil-based
    # length reported by ClassificationWrapper below.
    last_batch_policy = LastBatchPolicy.PARTIAL
    return ClassificationWrapper(
        DALIClassificationIterator(pipeline,
                                   # size=num_samples,
                                   last_batch_policy=last_batch_policy,
                                   # last_batch_padded=True,
                                   # fill_last_batch=split == "train",  # deprecated in newer DALI releases
                                   auto_reset=True),
        (num_samples + batch_size - 1) // batch_size)
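
For reference, a minimal usage sketch follows; the dataset path "/data/imagenet", batch sizes, and thread count are placeholders rather than values taken from this repository:

# Minimal usage sketch. The path /data/imagenet is a placeholder; it is assumed to contain
# ImageNet-style "train" and "val" subdirectories with one folder per class.
train_loader = get_imagenet_iter_dali("train", "/data/imagenet",
                                      batch_size=256, num_threads=4, crop=224)
val_loader = get_imagenet_iter_dali("val", "/data/imagenet",
                                    batch_size=200, num_threads=4, crop=224, val_size=256)

for images, labels in train_loader:
    # images: float32 NCHW batch already on the GPU; labels: 1-D LongTensor on cuda:0
    ...  # forward/backward pass goes here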
