@@ -67,36 +67,55 @@ def mnist_inversion_attack(net): | |||
load_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, load_dict) | |||
# get test data | |||
data_list = "../../common/dataset/MNIST/test" | |||
# get original data and their inferred fearures | |||
data_list = "../../common/dataset/MNIST/train" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size) | |||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
i = 0 | |||
batch_num = 1 | |||
sample_num = 10 | |||
sample_num = 30 | |||
for data in ds.create_tuple_iterator(output_numpy=True): | |||
i += 1 | |||
images = data[0].astype(np.float32) | |||
target_features = net(Tensor(images)).asnumpy() | |||
true_labels = data[1][: sample_num] | |||
target_features = net(Tensor(images)).asnumpy()[:sample_num] | |||
original_images = images[: sample_num] | |||
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) | |||
for n in range(1, sample_num+1): | |||
plt.subplot(2, sample_num, n) | |||
plt.gray() | |||
plt.imshow(images[n - 1].reshape(32, 32)) | |||
plt.subplot(2, sample_num, n + sample_num) | |||
plt.gray() | |||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||
plt.show() | |||
if i >= batch_num: | |||
break | |||
# evaluate the similarity between inversion images and original images | |||
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) | |||
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) | |||
# run attacking | |||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.1, 5]) | |||
inversion_images = inversion_attack.generate(target_features, iters=100) | |||
# get the predict results of inversion images on a new trained model | |||
net2 = LeNet5() | |||
new_ckpt_path = '../../common/networks/lenet5/new_trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
new_load_dict = load_checkpoint(new_ckpt_path) | |||
load_param_into_net(net2, new_load_dict) | |||
pred_labels = np.argmax(net2(Tensor(inversion_images).astype(np.float32)).asnumpy(), axis=1) | |||
# evaluate the quality of inversion images | |||
avg_l2_dis, avg_ssim, avg_confi = inversion_attack.evaluate(original_images, inversion_images, true_labels, net2) | |||
LOGGER.info(TAG, 'The average L2 distance between original images and inverted images is: {}'.format(avg_l2_dis)) | |||
LOGGER.info(TAG, 'The average ssim value between original images and inverted images is: {}'.format(avg_ssim)) | |||
LOGGER.info(TAG, 'The average prediction confidence on true labels of inverted images is: {}'.format(avg_confi)) | |||
LOGGER.info(TAG, 'True labels of original images are: %s' % true_labels) | |||
LOGGER.info(TAG, 'Predicted labels of inverted images are: %s' % pred_labels) | |||
# plot 10 images | |||
plot_num = min(sample_num, 10) | |||
for n in range(1, plot_num+1): | |||
plt.subplot(2, plot_num, n) | |||
if n == 1: | |||
plt.title('Original images', fontsize=16, loc='left') | |||
plt.gray() | |||
plt.imshow(images[n - 1].reshape(32, 32)) | |||
plt.subplot(2, plot_num, n + plot_num) | |||
if n == 1: | |||
plt.title('Inverted images', fontsize=16, loc='left') | |||
plt.gray() | |||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||
plt.show() | |||
if __name__ == '__main__': | |||
@@ -15,6 +15,7 @@ | |||
Inversion Attack | |||
""" | |||
import numpy as np | |||
from scipy.special import softmax | |||
from mindspore.nn import Cell, MSELoss | |||
from mindspore import Tensor | |||
@@ -166,18 +167,24 @@ class ImageInversionAttack: | |||
inversion_images.append(inversion_image_n) | |||
return np.concatenate(np.array(inversion_images)) | |||
def evaluate(self, original_images, inversion_images): | |||
def evaluate(self, original_images, inversion_images, labels=None, new_network=None): | |||
""" | |||
Compute the average L2 distance and SSIM value between original images and inversion images. | |||
Evaluate the quality of inverted images by three index: the average L2 distance and SSIM value between | |||
original images and inversion images, and the average of inverted images' confidence on true labels of inverted | |||
inferred by a new trained network. | |||
Args: | |||
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | |||
img_height). | |||
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | |||
img_height). | |||
labels (numpy.ndarray): Ground truth labels of original images. Default: None. | |||
new_network (Cell): A network whose structure contains all parts of self._network, but loaded with different | |||
checkpoint file. Default: None. | |||
Returns: | |||
tuple, the average l2 distance and average ssim value between original images and inversion images. | |||
tuple, average l2 distance, average ssim value and average confidence (if labels or new_network is None, | |||
then average confidence would be None). | |||
Examples: | |||
>>> net = LeNet5() | |||
@@ -188,15 +195,31 @@ class ImageInversionAttack: | |||
>>> ori_images = np.random.random((2, 1, 32, 32)) | |||
>>> result = inversion_attack.evaluate(ori_images, inver_images) | |||
>>> print(len(result)) | |||
2 | |||
3 | |||
""" | |||
check_numpy_param('original_images', original_images) | |||
check_numpy_param('inversion_images', inversion_images) | |||
if labels is not None: | |||
check_numpy_param('labels', labels) | |||
true_labels = np.squeeze(labels) | |||
if len(true_labels.shape) > 1: | |||
msg = 'Shape of true_labels should be (1, n) or (n,), but got {}'.format(true_labels.shape) | |||
raise ValueError(msg) | |||
if true_labels.size != original_images.shape[0]: | |||
msg = 'The size of true_labels should equal the number of images, but got {} and {}'.format( | |||
true_labels.size, original_images.shape[0]) | |||
raise ValueError(msg) | |||
if new_network is not None: | |||
check_param_type('new_network', new_network, Cell) | |||
LOGGER.info(TAG, 'Please make sure that the network you pass is loaded with different checkpoint files ' | |||
'compared with that of self._network.') | |||
img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | |||
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | |||
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | |||
' but got {} and {}'.format(img_1.shape, img_2.shape) | |||
raise ValueError(msg) | |||
total_l2_distance = 0 | |||
total_ssim = 0 | |||
img_1 = img_1.transpose(0, 2, 3, 1) | |||
@@ -207,4 +230,9 @@ class ImageInversionAttack: | |||
total_ssim += compute_ssim(img_1[i], img_2[i]) | |||
avg_l2_dis = total_l2_distance / img_1.shape[0] | |||
avg_ssim = total_ssim / img_1.shape[0] | |||
return avg_l2_dis, avg_ssim | |||
avg_confi = None | |||
if (new_network is not None) and (labels is not None): | |||
pred_logits = new_network(Tensor(inversion_images.astype(np.float32))).asnumpy() | |||
logits_softmax = softmax(pred_logits, axis=1) | |||
avg_confi = np.mean(logits_softmax[np.arange(img_1.shape[0]), true_labels]) | |||
return avg_l2_dis, avg_ssim, avg_confi |
@@ -42,3 +42,20 @@ def test_inversion_attack(): | |||
avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||
assert 0 < avg_ssim[1] < 1 | |||
assert target_features.shape[0] == inversion_images.shape[0] | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_inversion_attack2(): | |||
net = Net() | |||
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) | |||
target_features = np.random.random((2, 10)).astype(np.float32) | |||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
inversion_images = inversion_attack.generate(target_features, iters=10) | |||
true_labels = np.array([1, 2]) | |||
new_net = Net() | |||
indexes = inversion_attack.evaluate(original_images, inversion_images, true_labels, new_net) | |||
assert len(indexes) == 3 |