From 0f41805f19de1be50ed684e4664a52ccc45350bd Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Mon, 22 Feb 2021 16:58:20 +0800 Subject: [PATCH] Fix two issues. --- examples/common/dataset/data_processing.py | 7 +++---- mindarmour/privacy/evaluation/inversion_attack.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/common/dataset/data_processing.py b/examples/common/dataset/data_processing.py index 8baf5ba..cf8b850 100644 --- a/examples/common/dataset/data_processing.py +++ b/examples/common/dataset/data_processing.py @@ -105,13 +105,12 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz data_set = data_set.map(input_columns="label", operations=type_cast_op) data_set = data_set.map(input_columns="image", operations=c_trans) - # apply repeat operations - data_set = data_set.repeat(repeat_num) - # apply shuffle operations - # data_set = data_set.shuffle(buffer_size=1000) + data_set = data_set.shuffle(buffer_size=1000) # apply batch operations data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) + # apply repeat operations + data_set = data_set.repeat(repeat_num) return data_set diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index 8080d03..4fd959e 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -107,7 +107,7 @@ class ImageInversionAttack: for sub_loss_weight in loss_weights: check_value_positive('sub_loss_weight', sub_loss_weight) self._loss = InversionLoss(self._network, loss_weights) - self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) + self._input_shape = check_param_type('input_shape', input_shape, tuple) for shape_dim in input_shape: check_int_positive('shape_dim', shape_dim) self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple])