Browse Source

Fix two issues.

tags/v1.2.1
jin-xiulang 4 years ago
parent
commit
0f41805f19
2 changed files with 4 additions and 5 deletions
  1. +3
    -4
      examples/common/dataset/data_processing.py
  2. +1
    -1
      mindarmour/privacy/evaluation/inversion_attack.py

+ 3
- 4
examples/common/dataset/data_processing.py View File

@@ -105,13 +105,12 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)

# apply repeat operations
data_set = data_set.repeat(repeat_num)

# apply shuffle operations
# data_set = data_set.shuffle(buffer_size=1000)
data_set = data_set.shuffle(buffer_size=1000)

# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set

+ 1
- 1
mindarmour/privacy/evaluation/inversion_attack.py View File

@@ -107,7 +107,7 @@ class ImageInversionAttack:
for sub_loss_weight in loss_weights:
check_value_positive('sub_loss_weight', sub_loss_weight)
self._loss = InversionLoss(self._network, loss_weights)
self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple])
self._input_shape = check_param_type('input_shape', input_shape, tuple)
for shape_dim in input_shape:
check_int_positive('shape_dim', shape_dim)
self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple])


Loading…
Cancel
Save