Browse Source

Modify inversion attack.

tags/v1.2.1
jin-xiulang 4 years ago
parent
commit
fde222bf2b
3 changed files with 89 additions and 25 deletions
  1. +39
    -20
      examples/privacy/inversion_attack/mnist_inversion_attack.py
  2. +33
    -5
      mindarmour/privacy/evaluation/inversion_attack.py
  3. +17
    -0
      tests/ut/python/privacy/evaluation/test_inversion_attack.py

+ 39
- 20
examples/privacy/inversion_attack/mnist_inversion_attack.py View File

@@ -67,36 +67,55 @@ def mnist_inversion_attack(net):
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)

# get test data
data_list = "../../common/dataset/MNIST/test"
# get original data and their inferred fearures
data_list = "../../common/dataset/MNIST/train"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size)

inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])

i = 0
batch_num = 1
sample_num = 10
sample_num = 30
for data in ds.create_tuple_iterator(output_numpy=True):
i += 1
images = data[0].astype(np.float32)
target_features = net(Tensor(images)).asnumpy()
true_labels = data[1][: sample_num]
target_features = net(Tensor(images)).asnumpy()[:sample_num]
original_images = images[: sample_num]
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100)
for n in range(1, sample_num+1):
plt.subplot(2, sample_num, n)
plt.gray()
plt.imshow(images[n - 1].reshape(32, 32))
plt.subplot(2, sample_num, n + sample_num)
plt.gray()
plt.imshow(inversion_images[n - 1].reshape(32, 32))
plt.show()
if i >= batch_num:
break
# evaluate the similarity between inversion images and original images
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis))
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim))

# run attacking
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.1, 5])
inversion_images = inversion_attack.generate(target_features, iters=100)

# get the predict results of inversion images on a new trained model
net2 = LeNet5()
new_ckpt_path = '../../common/networks/lenet5/new_trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
new_load_dict = load_checkpoint(new_ckpt_path)
load_param_into_net(net2, new_load_dict)
pred_labels = np.argmax(net2(Tensor(inversion_images).astype(np.float32)).asnumpy(), axis=1)

# evaluate the quality of inversion images
avg_l2_dis, avg_ssim, avg_confi = inversion_attack.evaluate(original_images, inversion_images, true_labels, net2)
LOGGER.info(TAG, 'The average L2 distance between original images and inverted images is: {}'.format(avg_l2_dis))
LOGGER.info(TAG, 'The average ssim value between original images and inverted images is: {}'.format(avg_ssim))
LOGGER.info(TAG, 'The average prediction confidence on true labels of inverted images is: {}'.format(avg_confi))
LOGGER.info(TAG, 'True labels of original images are: %s' % true_labels)
LOGGER.info(TAG, 'Predicted labels of inverted images are: %s' % pred_labels)

# plot 10 images
plot_num = min(sample_num, 10)
for n in range(1, plot_num+1):
plt.subplot(2, plot_num, n)
if n == 1:
plt.title('Original images', fontsize=16, loc='left')
plt.gray()
plt.imshow(images[n - 1].reshape(32, 32))
plt.subplot(2, plot_num, n + plot_num)
if n == 1:
plt.title('Inverted images', fontsize=16, loc='left')
plt.gray()
plt.imshow(inversion_images[n - 1].reshape(32, 32))
plt.show()


if __name__ == '__main__':


+ 33
- 5
mindarmour/privacy/evaluation/inversion_attack.py View File

@@ -15,6 +15,7 @@
Inversion Attack
"""
import numpy as np
from scipy.special import softmax

from mindspore.nn import Cell, MSELoss
from mindspore import Tensor
@@ -166,18 +167,24 @@ class ImageInversionAttack:
inversion_images.append(inversion_image_n)
return np.concatenate(np.array(inversion_images))

def evaluate(self, original_images, inversion_images):
def evaluate(self, original_images, inversion_images, labels=None, new_network=None):
"""
Compute the average L2 distance and SSIM value between original images and inversion images.
Evaluate the quality of inverted images by three index: the average L2 distance and SSIM value between
original images and inversion images, and the average of inverted images' confidence on true labels of inverted
inferred by a new trained network.

Args:
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width,
img_height).
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width,
img_height).
labels (numpy.ndarray): Ground truth labels of original images. Default: None.
new_network (Cell): A network whose structure contains all parts of self._network, but loaded with different
checkpoint file. Default: None.

Returns:
tuple, the average l2 distance and average ssim value between original images and inversion images.
tuple, average l2 distance, average ssim value and average confidence (if labels or new_network is None,
then average confidence would be None).

Examples:
>>> net = LeNet5()
@@ -188,15 +195,31 @@ class ImageInversionAttack:
>>> ori_images = np.random.random((2, 1, 32, 32))
>>> result = inversion_attack.evaluate(ori_images, inver_images)
>>> print(len(result))
2
3
"""
check_numpy_param('original_images', original_images)
check_numpy_param('inversion_images', inversion_images)
if labels is not None:
check_numpy_param('labels', labels)
true_labels = np.squeeze(labels)
if len(true_labels.shape) > 1:
msg = 'Shape of true_labels should be (1, n) or (n,), but got {}'.format(true_labels.shape)
raise ValueError(msg)
if true_labels.size != original_images.shape[0]:
msg = 'The size of true_labels should equal the number of images, but got {} and {}'.format(
true_labels.size, original_images.shape[0])
raise ValueError(msg)
if new_network is not None:
check_param_type('new_network', new_network, Cell)
LOGGER.info(TAG, 'Please make sure that the network you pass is loaded with different checkpoint files '
'compared with that of self._network.')

img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images)
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3):
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \
' but got {} and {}'.format(img_1.shape, img_2.shape)
raise ValueError(msg)

total_l2_distance = 0
total_ssim = 0
img_1 = img_1.transpose(0, 2, 3, 1)
@@ -207,4 +230,9 @@ class ImageInversionAttack:
total_ssim += compute_ssim(img_1[i], img_2[i])
avg_l2_dis = total_l2_distance / img_1.shape[0]
avg_ssim = total_ssim / img_1.shape[0]
return avg_l2_dis, avg_ssim
avg_confi = None
if (new_network is not None) and (labels is not None):
pred_logits = new_network(Tensor(inversion_images.astype(np.float32))).asnumpy()
logits_softmax = softmax(pred_logits, axis=1)
avg_confi = np.mean(logits_softmax[np.arange(img_1.shape[0]), true_labels])
return avg_l2_dis, avg_ssim, avg_confi

+ 17
- 0
tests/ut/python/privacy/evaluation/test_inversion_attack.py View File

@@ -42,3 +42,20 @@ def test_inversion_attack():
avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
assert 0 < avg_ssim[1] < 1
assert target_features.shape[0] == inversion_images.shape[0]


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_inversion_attack2():
net = Net()
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32)
target_features = np.random.random((2, 10)).astype(np.float32)
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
inversion_images = inversion_attack.generate(target_features, iters=10)
true_labels = np.array([1, 2])
new_net = Net()
indexes = inversion_attack.evaluate(original_images, inversion_images, true_labels, new_net)
assert len(indexes) == 3

Loading…
Cancel
Save