
data_processing.py 7.7 kB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data processing: dataset pipelines for MNIST, CIFAR-10/100 and ImageFolder-style data."""
import os

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
import mindspore.common.dtype as mstype

def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
                           num_samples=None, num_parallel_workers=1, sparse=True):
    """
    Create a MNIST dataset for training or testing.
    """
    # define dataset
    ds1 = ds.MnistDataset(data_path, num_samples=num_samples)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on labels and images; with sparse=False the labels
    # are one-hot encoded and cast to float32 instead of int32 class indices
    if not sparse:
        one_hot_enco = C.OneHot(10)
        ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
                      num_parallel_workers=num_parallel_workers)
        type_cast_op = C.TypeCast(mstype.float32)
    ds1 = ds1.map(input_columns="label", operations=type_cast_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=resize_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=rescale_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
                  num_parallel_workers=num_parallel_workers)

    # apply dataset operations: shuffle -> batch -> repeat
    buffer_size = 10000
    ds1 = ds1.shuffle(buffer_size=buffer_size)
    ds1 = ds1.batch(batch_size, drop_remainder=True)
    ds1 = ds1.repeat(repeat_size)
    return ds1
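
As a quick sanity check, the pipeline can be iterated directly. A minimal sketch, assuming the MNIST binaries have been extracted under ./MNIST_Data/train (a hypothetical path; the output_numpy flag follows MindSpore 1.x and may be absent in older releases):

    # usage sketch: paths are placeholders, adjust to your data layout
    mnist_ds = generate_mnist_dataset("./MNIST_Data/train", batch_size=32, sparse=True)
    for item in mnist_ds.create_dict_iterator(output_numpy=True):
        print(item["image"].shape, item["label"].shape)  # (32, 1, 32, 32) (32,)
        break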

def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1,
                          training=True, num_samples=None, shuffle=True):
    """Create a CIFAR-100 dataset pipeline for VGG training or evaluation."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "train")
    if not training:
        data_dir = os.path.join(data_home, "test")

    if num_samples is not None:
        data_set = ds.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id,
                                      num_samples=num_samples, shuffle=shuffle)
    else:
        data_set = ds.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)

    # keep only the fine-grained labels, renamed to the standard "label" column
    input_columns = ["fine_label"]
    output_columns = ["label"]
    data_set = data_set.rename(input_columns=input_columns, output_columns=output_columns)
    data_set = data_set.project(["image", "label"])

    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = CV.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = CV.RandomHorizontalFlip()
    resize_op = CV.Resize(image_size)  # interpolation default BILINEAR
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=1000)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)
    return data_set
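
A minimal usage sketch, assuming the CIFAR-100 binary files live under ./cifar-100-binary with train/ and test/ subfolders (a hypothetical layout):

    # usage sketch: directory layout is an assumption
    train_ds = vgg_create_dataset100("./cifar-100-binary", image_size=(224, 224),
                                     batch_size=64, training=True)
    print(train_ds.get_dataset_size())  # number of batches per epoch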

def create_dataset_imagenet(path, batch_size=32, repeat_size=20, status="train", target="GPU"):
    """Create a dataset pipeline from an ImageFolder-style directory.

    Note: `status` and `target` are accepted for API compatibility but are
    not used inside this function.
    """
    image_ds = ds.ImageFolderDataset(path, decode=True)
    rescale = 1.0 / 255.0
    shift = 0.0
    cfg = {'num_classes': 10,
           'learning_rate': 0.002,
           'momentum': 0.9,
           'epoch_size': 30,
           'batch_size': 32,
           'buffer_size': 1000,
           'image_height': 224,
           'image_width': 224,
           'save_checkpoint_steps': 1562,
           'keep_checkpoint_max': 10}

    # define map operations
    resize_op = CV.Resize((cfg['image_height'], cfg['image_width']))
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
    random_horizontal_op = CV.RandomHorizontalFlip()
    channel_swap_op = CV.HWC2CHW()
    typecast_op = C.TypeCast(mstype.int32)

    # apply map operations on labels and images
    image_ds = image_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=6)
    image_ds = image_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=6)

    # apply shuffle, batch and repeat; batching before repeat so every
    # epoch yields complete, fixed-shape batches
    image_ds = image_ds.shuffle(buffer_size=cfg['buffer_size'])
    image_ds = image_ds.batch(batch_size, drop_remainder=True)
    image_ds = image_ds.repeat(repeat_size)
    return image_ds
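
A minimal usage sketch, assuming a hypothetical folder whose subdirectory names serve as class labels (this is how ImageFolderDataset infers labels):

    # usage sketch: path is a placeholder
    img_ds = create_dataset_imagenet("./imagenet_small/train", batch_size=32, repeat_size=1)
    for images, labels in img_ds.create_tuple_iterator(output_numpy=True):
        print(images.shape)  # (32, 3, 224, 224) after HWC2CHW and batching
        break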

def create_dataset_cifar(data_path, image_height, image_width, repeat_num=1, training=True):
    """
    Create a CIFAR-10 dataset for later use, such as training or inference.
    """
    cifar_ds = ds.Cifar10Dataset(data_path)
    resize_height = image_height  # 224
    resize_width = image_width  # 224
    rescale = 1.0 / 255.0
    shift = 0.0
    batch_size = 32

    # define map operations
    random_crop_op = CV.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = CV.RandomHorizontalFlip()
    resize_op = CV.Resize((resize_height, resize_width))  # interpolation default BILINEAR
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    changeswap_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
    cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")

    # apply shuffle operations
    cifar_ds = cifar_ds.shuffle(buffer_size=10)

    # apply batch operations
    cifar_ds = cifar_ds.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    cifar_ds = cifar_ds.repeat(repeat_num)
    return cifar_ds
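
A minimal usage sketch, assuming the CIFAR-10 binary files sit under ./cifar-10-batches-bin (a hypothetical path):

    # usage sketch: path is a placeholder
    eval_ds = create_dataset_cifar("./cifar-10-batches-bin", image_height=224,
                                   image_width=224, training=False)
    print(eval_ds.get_dataset_size())  # batches per pass over the test set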

MindArmour focuses on the security and privacy problems of AI. It is dedicated to strengthening the security and trustworthiness of models and protecting users' data privacy, and consists of three main modules: the adversarial robustness module, the Fuzz Testing module, and the privacy protection and evaluation module.

Adversarial robustness module: this module evaluates a model's robustness against adversarial examples and provides model-hardening methods to strengthen a model's resistance to adversarial attacks. It contains four submodules: adversarial example generation, adversarial example detection, model defense, and attack/defense evaluation.
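The dataset pipelines above feed directly into these workflows. As a hedged sketch of the adversarial example generation submodule, the snippet below crafts FGSM adversarial examples from the MNIST pipeline defined earlier. The import path and the FastGradientSignMethod signature follow MindArmour 1.x and may differ in other releases; `net` stands in for any trained MindSpore classification network you provide, and the data path is a placeholder:

    # hedged sketch: API path follows MindArmour 1.x; `net` is a trained
    # network supplied by the caller (hypothetical here)
    import numpy as np
    from mindspore import nn
    from mindarmour.adv_robustness.attacks import FastGradientSignMethod

    # sparse=False so labels come out one-hot, matching a non-sparse loss
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)  # eps: perturbation budget
    test_ds = generate_mnist_dataset("./MNIST_Data/test", batch_size=32, sparse=False)
    images, labels = next(test_ds.create_tuple_iterator(output_numpy=True))
    adv_images = attack.generate(images.astype(np.float32), labels)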