From fde222bf2bb2a2ac7ddf9d80acb0fc958b041974 Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Tue, 9 Feb 2021 18:37:56 +0800 Subject: [PATCH] Modify inversion attack. --- .../inversion_attack/mnist_inversion_attack.py | 59 ++++++++++++++-------- mindarmour/privacy/evaluation/inversion_attack.py | 38 ++++++++++++-- .../privacy/evaluation/test_inversion_attack.py | 17 +++++++ 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/examples/privacy/inversion_attack/mnist_inversion_attack.py b/examples/privacy/inversion_attack/mnist_inversion_attack.py index 9921091..75b7b36 100644 --- a/examples/privacy/inversion_attack/mnist_inversion_attack.py +++ b/examples/privacy/inversion_attack/mnist_inversion_attack.py @@ -67,36 +67,55 @@ def mnist_inversion_attack(net): load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) - # get test data - data_list = "../../common/dataset/MNIST/test" + # get original data and their inferred features + data_list = "../../common/dataset/MNIST/train" batch_size = 32 ds = generate_mnist_dataset(data_list, batch_size) - - inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) - i = 0 batch_num = 1 - sample_num = 10 + sample_num = 30 for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) - target_features = net(Tensor(images)).asnumpy() + true_labels = data[1][: sample_num] + target_features = net(Tensor(images)).asnumpy()[:sample_num] original_images = images[: sample_num] - inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) - for n in range(1, sample_num+1): - plt.subplot(2, sample_num, n) - plt.gray() - plt.imshow(images[n - 1].reshape(32, 32)) - plt.subplot(2, sample_num, n + sample_num) - plt.gray() - plt.imshow(inversion_images[n - 1].reshape(32, 32)) - plt.show() if i >= batch_num: break - # evaluate the similarity between inversion images and original images - 
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) - LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) - LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) + + # run attacking + inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.1, 5]) + inversion_images = inversion_attack.generate(target_features, iters=100) + + # get the predict results of inversion images on a new trained model + net2 = LeNet5() + new_ckpt_path = '../../common/networks/lenet5/new_trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + new_load_dict = load_checkpoint(new_ckpt_path) + load_param_into_net(net2, new_load_dict) + pred_labels = np.argmax(net2(Tensor(inversion_images).astype(np.float32)).asnumpy(), axis=1) + + # evaluate the quality of inversion images + avg_l2_dis, avg_ssim, avg_confi = inversion_attack.evaluate(original_images, inversion_images, true_labels, net2) + LOGGER.info(TAG, 'The average L2 distance between original images and inverted images is: {}'.format(avg_l2_dis)) + LOGGER.info(TAG, 'The average ssim value between original images and inverted images is: {}'.format(avg_ssim)) + LOGGER.info(TAG, 'The average prediction confidence on true labels of inverted images is: {}'.format(avg_confi)) + LOGGER.info(TAG, 'True labels of original images are: %s' % true_labels) + LOGGER.info(TAG, 'Predicted labels of inverted images are: %s' % pred_labels) + + # plot 10 images + plot_num = min(sample_num, 10) + for n in range(1, plot_num+1): + plt.subplot(2, plot_num, n) + if n == 1: + plt.title('Original images', fontsize=16, loc='left') + plt.gray() + plt.imshow(images[n - 1].reshape(32, 32)) + plt.subplot(2, plot_num, n + plot_num) + if n == 1: + plt.title('Inverted images', fontsize=16, loc='left') + plt.gray() + plt.imshow(inversion_images[n - 1].reshape(32, 32)) + 
plt.show() if __name__ == '__main__': diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index 47a18a0..6650c06 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -15,6 +15,7 @@ Inversion Attack """ import numpy as np +from scipy.special import softmax from mindspore.nn import Cell, MSELoss from mindspore import Tensor @@ -166,18 +167,24 @@ class ImageInversionAttack: inversion_images.append(inversion_image_n) return np.concatenate(np.array(inversion_images)) - def evaluate(self, original_images, inversion_images): + def evaluate(self, original_images, inversion_images, labels=None, new_network=None): """ - Compute the average L2 distance and SSIM value between original images and inversion images. + Evaluate the quality of inverted images by three indexes: the average L2 distance and SSIM value between + original images and inversion images, and the average confidence that a newly trained network assigns to + the true labels of the inverted images. Args: original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, img_height). inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, img_height). + labels (numpy.ndarray): Ground truth labels of original images. Default: None. + new_network (Cell): A network whose structure contains all parts of self._network, but loaded with a different + checkpoint file. Default: None. Returns: - tuple, the average l2 distance and average ssim value between original images and inversion images. + tuple, average l2 distance, average ssim value and average confidence (if labels or new_network is None, + then the average confidence is None). 
Examples: >>> net = LeNet5() @@ -188,15 +195,31 @@ class ImageInversionAttack: >>> ori_images = np.random.random((2, 1, 32, 32)) >>> result = inversion_attack.evaluate(ori_images, inver_images) >>> print(len(result)) - 2 + 3 """ check_numpy_param('original_images', original_images) check_numpy_param('inversion_images', inversion_images) + if labels is not None: + check_numpy_param('labels', labels) + true_labels = np.squeeze(labels) + if len(true_labels.shape) > 1: + msg = 'Shape of true_labels should be (1, n) or (n,), but got {}'.format(true_labels.shape) + raise ValueError(msg) + if true_labels.size != original_images.shape[0]: + msg = 'The size of true_labels should equal the number of images, but got {} and {}'.format( + true_labels.size, original_images.shape[0]) + raise ValueError(msg) + if new_network is not None: + check_param_type('new_network', new_network, Cell) + LOGGER.info(TAG, 'Please make sure that the network you pass is loaded with different checkpoint files ' + 'compared with that of self._network.') + img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ ' but got {} and {}'.format(img_1.shape, img_2.shape) raise ValueError(msg) + total_l2_distance = 0 total_ssim = 0 img_1 = img_1.transpose(0, 2, 3, 1) @@ -207,4 +230,9 @@ class ImageInversionAttack: total_ssim += compute_ssim(img_1[i], img_2[i]) avg_l2_dis = total_l2_distance / img_1.shape[0] avg_ssim = total_ssim / img_1.shape[0] - return avg_l2_dis, avg_ssim + avg_confi = None + if (new_network is not None) and (labels is not None): + pred_logits = new_network(Tensor(inversion_images.astype(np.float32))).asnumpy() + logits_softmax = softmax(pred_logits, axis=1) + avg_confi = np.mean(logits_softmax[np.arange(img_1.shape[0]), true_labels]) + return avg_l2_dis, 
avg_ssim, avg_confi diff --git a/tests/ut/python/privacy/evaluation/test_inversion_attack.py b/tests/ut/python/privacy/evaluation/test_inversion_attack.py index a6a08c3..23651f6 100644 --- a/tests/ut/python/privacy/evaluation/test_inversion_attack.py +++ b/tests/ut/python/privacy/evaluation/test_inversion_attack.py @@ -42,3 +42,20 @@ def test_inversion_attack(): avg_ssim = inversion_attack.evaluate(original_images, inversion_images) assert 0 < avg_ssim[1] < 1 assert target_features.shape[0] == inversion_images.shape[0] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_inversion_attack2(): + net = Net() + original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) + target_features = np.random.random((2, 10)).astype(np.float32) + inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) + inversion_images = inversion_attack.generate(target_features, iters=10) + true_labels = np.array([1, 2]) + new_net = Net() + indexes = inversion_attack.evaluate(original_images, inversion_images, true_labels, new_net) + assert len(indexes) == 3