From a05744f8c9471f20de161b00923fc6afbcc1af39 Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Sat, 30 Jan 2021 10:38:35 +0800
Subject: [PATCH] Add image inversion attack method

---
 examples/privacy/README.md                         |  14 +-
 examples/privacy/inversion_attack/__init__.py      |   0
 .../inversion_attack/mnist_inversion_attack.py     | 108 +++++++++++
 mindarmour/__init__.py                             |  10 +-
 .../evaluations/attack_evaluation.py               |  88 +--------
 mindarmour/privacy/evaluation/inversion_attack.py  | 210 +++++++++++++++++++++
 mindarmour/utils/_check_param.py                   |   5 +
 mindarmour/utils/util.py                           | 112 ++++++++++-
 .../privacy/evaluation/test_inversion_attack.py    |  41 ++++
 9 files changed, 503 insertions(+), 85 deletions(-)
 create mode 100644 examples/privacy/inversion_attack/__init__.py
 create mode 100644 examples/privacy/inversion_attack/mnist_inversion_attack.py
 create mode 100644 mindarmour/privacy/evaluation/inversion_attack.py
 create mode 100644 tests/ut/python/privacy/evaluation/test_inversion_attack.py

diff --git a/examples/privacy/README.md b/examples/privacy/README.md
index 837d4b5..cee1352 100644
--- a/examples/privacy/README.md
+++ b/examples/privacy/README.md
@@ -42,7 +42,7 @@ python train.py --data_path home_path_to_cifar100 --ckpt_path ./
 python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt
 ```
 
-## 4. suppress privacy training
+## 4. Suppress privacy training
 
 With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and
 fully connected layers) are set to zero as the training step grows, which can
@@ -52,3 +52,15 @@ With suppress privacy mechanism, the values of some trainable parameters (such
 cd examples/privacy/sup_privacy
 python sup_privacy.py
 ```
+
+## 5. Image inversion attack
+
+Inversion attack means reconstructing an image based on its deep representations. For example,
+reconstructing an MNIST image based on its outputs from LeNet5. The mechanism behind it is that a well-trained
+model can "remember" its training dataset. Therefore, inversion attacks can be used to estimate the privacy
+leakage of training tasks.
+
+```sh
+cd examples/privacy/inversion_attack
+python mnist_inversion_attack.py
+```
\ No newline at end of file
diff --git a/examples/privacy/inversion_attack/__init__.py b/examples/privacy/inversion_attack/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/privacy/inversion_attack/mnist_inversion_attack.py b/examples/privacy/inversion_attack/mnist_inversion_attack.py
new file mode 100644
index 0000000..9921091
--- /dev/null
+++ b/examples/privacy/inversion_attack/mnist_inversion_attack.py
@@ -0,0 +1,108 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Examples of image inversion attack
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore import Tensor, context
+from mindspore import nn
+from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
+from mindarmour.utils.logger import LogUtil
+
+from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize
+from examples.common.dataset.data_processing import generate_mnist_dataset
+
+LOGGER = LogUtil.get_instance()
+LOGGER.set_level('INFO')
+TAG = 'InversionAttack'
+
+
+# pylint: disable=invalid-name
+class LeNet5_part(nn.Cell):
+    """
+    Part of LeNet5 network.
+    """
+    def __init__(self):
+        super(LeNet5_part, self).__init__()
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16*5*5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, 10)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        return x
+
+
+def mnist_inversion_attack(net):
+    """
+    Image inversion attack based on LeNet5 and MNIST dataset.
+    """
+    # load the trained network
+    ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
+    load_dict = load_checkpoint(ckpt_path)
+    load_param_into_net(net, load_dict)
+
+    # get test data
+    data_list = "../../common/dataset/MNIST/test"
+    batch_size = 32
+    ds = generate_mnist_dataset(data_list, batch_size)
+
+    inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
+
+    i = 0
+    batch_num = 1
+    sample_num = 10
+    for data in ds.create_tuple_iterator(output_numpy=True):
+        i += 1
+        images = data[0].astype(np.float32)
+        target_features = net(Tensor(images)).asnumpy()
+        original_images = images[: sample_num]
+        inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100)
+        for n in range(1, sample_num+1):
+            plt.subplot(2, sample_num, n)
+            plt.gray()
+            plt.imshow(images[n - 1].reshape(32, 32))
+            plt.subplot(2, sample_num, n + sample_num)
+            plt.gray()
+            plt.imshow(inversion_images[n - 1].reshape(32, 32))
+        plt.show()
+        if i >= batch_num:
+            break
+    # evaluate the similarity between inversion images and original images
+    avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
+    LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis))
+    LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim))
+
+
+if __name__ == '__main__':
+    # device_target can be "CPU", "GPU" or "Ascend"
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    # attack based on the complete LeNet5
+    mnist_inversion_attack(LeNet5())
+    # attack based on part of LeNet5. The network is shallower, which can lead to a better attack result
+    mnist_inversion_attack(LeNet5_part())
diff --git a/mindarmour/__init__.py b/mindarmour/__init__.py
index 0526847..9611107 100644
--- a/mindarmour/__init__.py
+++ b/mindarmour/__init__.py
@@ -9,6 +9,10 @@ from .adv_robustness.detectors.detector import Detector
 from .fuzz_testing.fuzzing import Fuzzer
 from .privacy.diff_privacy import DPModel
 from .privacy.evaluation.membership_inference import MembershipInference
+from .privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl
+from .privacy.sup_privacy.train.model import SuppressModel
+from .privacy.sup_privacy.mask_monitor.masker import SuppressMasker
+from .privacy.evaluation.inversion_attack import ImageInversionAttack
 
 __all__ = ['Attack',
            'BlackModel',
@@ -16,4 +20,8 @@ __all__ = ['Attack',
            'Defense',
            'Fuzzer',
            'DPModel',
-           'MembershipInference']
+           'MembershipInference',
+           'SuppressModel',
+           'SuppressCtrl',
+           'SuppressMasker',
+           'ImageInversionAttack']
diff --git a/mindarmour/adv_robustness/evaluations/attack_evaluation.py b/mindarmour/adv_robustness/evaluations/attack_evaluation.py
index b828dde..c327510 100644
--- a/mindarmour/adv_robustness/evaluations/attack_evaluation.py
+++ b/mindarmour/adv_robustness/evaluations/attack_evaluation.py
@@ -17,84 +17,15 @@ Attack evaluation.
 import numpy as np
 
-from scipy.ndimage.filters import convolve
-
 from mindarmour.utils.logger import LogUtil
 from mindarmour.utils._check_param import check_pair_numpy_param, \
-    check_param_type, check_numpy_param, check_equal_shape
+    check_param_type, check_numpy_param
+from mindarmour.utils.util import calculate_lp_distance, compute_ssim
 
 LOGGER = LogUtil.get_instance()
 TAG = 'AttackEvaluate'
 
-
-def _compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11):
-    """
-    compute structural similarity.
-    Args:
-        img_1 (numpy.ndarray): The first image to be compared.
-        img_2 (numpy.ndarray): The second image to be compared.
-        kernel_sigma (float): Gassian kernel param. Default: 1.5.
-        kernel_width (int): Another Gassian kernel param. Default: 11.
-
-    Returns:
-        float, structural similarity.
- """ - img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) - - if len(img_1.shape) > 2: - total_ssim = 0 - for i in range(img_1.shape[2]): - total_ssim += _compute_ssim(img_1[:, :, i], img_2[:, :, i]) - return total_ssim / 3 - - # Create gaussian kernel - gaussian_kernel = np.zeros((kernel_width, kernel_width)) - for i in range(kernel_width): - for j in range(kernel_width): - gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( - - (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) - - img_1 = img_1.astype(np.float32) - img_2 = img_2.astype(np.float32) - - img_sq_1 = img_1**2 - img_sq_2 = img_2**2 - img_12 = img_1*img_2 - - # Mean - img_mu_1 = convolve(img_1, gaussian_kernel) - img_mu_2 = convolve(img_2, gaussian_kernel) - - # Mean square - img_mu_sq_1 = img_mu_1**2 - img_mu_sq_2 = img_mu_2**2 - img_mu_12 = img_mu_1*img_mu_2 - - # Variances - img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) - img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) - - # Covariance - img_sigma_12 = convolve(img_12, gaussian_kernel) - - # Centered squares of variances - img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 - img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 - img_sigma_12 = img_sigma_12 - img_mu_12 - - k_1 = 0.01 - k_2 = 0.03 - c_1 = (k_1*255)**2 - c_2 = (k_2*255)**2 - - # Calculate ssim - num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) - den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 - + img_sigma_sq_2 + c_2) - res = np.average(num_ssim / den_ssim) - return res - - class AttackEvaluate: """ Evaluation metrics of attack methods. @@ -217,16 +148,11 @@ class AttackEvaluate: l0_dist = 0 l2_dist = 0 linf_dist = 0 - avoid_zero_div = 1e-14 for i in idxes: - diff = (self._adv_inputs[i] - self._inputs[i]).flatten() - data = self._inputs[i].flatten() - l0_dist += np.linalg.norm(diff, ord=0) \ - / (np.linalg.norm(data, ord=0) + avoid_zero_div) - l2_dist += np.linalg.norm(diff, ord=2) \ - / (np.linalg.norm(data, ord=2) + avoid_zero_div) - linf_dist += np.linalg.norm(diff, ord=np.inf) \ - / (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) + l0_dist_i, l2_dist_i, linf_dist_i = calculate_lp_distance(self._inputs[i], self._adv_inputs[i]) + l0_dist += l0_dist_i + l2_dist += l2_dist_i + linf_dist += linf_dist_i return l0_dist / success_num, l2_dist / success_num, \ linf_dist / success_num @@ -249,7 +175,7 @@ class AttackEvaluate: total_ssim = 0.0 for _, i in enumerate(self._success_idxes): - total_ssim += _compute_ssim(self._adv_inputs[i], self._inputs[i]) + total_ssim += compute_ssim(self._adv_inputs[i], self._inputs[i]) return total_ssim / success_num diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py new file mode 100644 index 0000000..47a18a0 --- /dev/null +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -0,0 +1,210 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Inversion Attack +""" +import numpy as np + +from mindspore.nn import Cell, MSELoss +from mindspore import Tensor +from mindspore.ops import operations as P + +from mindarmour.utils.util import GradWrapWithLoss +from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \ + check_int_positive, check_numpy_param, check_value_positive, check_equal_shape +from mindarmour.utils.logger import LogUtil +from mindarmour.utils.util import calculate_lp_distance, compute_ssim + +LOGGER = LogUtil.get_instance() +LOGGER.set_level('INFO') +TAG = 'Image inversion attack' + + +class InversionLoss(Cell): + """ + The loss function for inversion attack. + + Args: + network (Cell): The network used to infer images' deep representations. + weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to + obtain better results. + """ + def __init__(self, network, weights): + super(InversionLoss, self).__init__() + self._network = check_param_type('network', network, Cell) + self._mse_loss = MSELoss() + self._weights = check_param_multi_types('weights', weights, [list, tuple]) + self._get_shape = P.Shape() + + def construct(self, input_data, target_features): + """ + Calculate the inversion attack loss, which consists of three parts. Loss_1 is for evaluating the difference + between the target deep representations and current representations; Loss_2 is for evaluating the continuity + between adjacent pixels; Loss_3 is for regularization. + + Args: + input_data (Tensor): The reconstructed image during inversion attack. + target_features (Tensor): Deep representations of the original image. + + Returns: + Tensor, inversion attack loss of the current iteration. + """ + output = self._network(input_data) + loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) + + data_shape = self._get_shape(input_data) + split_op_1 = P.Split(2, data_shape[2]) + split_op_2 = P.Split(3, data_shape[3]) + data_split_1 = split_op_1(input_data) + data_split_2 = split_op_2(input_data) + loss_2 = 0 + for i in range(1, data_shape[2]): + loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) + for j in range(1, data_shape[3]): + loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) + + loss_3 = self._mse_loss(input_data, 0) + + loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] + return loss + + +class ImageInversionAttack: + """ + An attack method used to reconstruct images by inverting their deep representations. + + References: `Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. + 2014. `_ + + Args: + network (Cell): The network used to infer images' deep representations. + input_shape (tuple): Data shape of single network input, which should be in accordance with the given + network. The format of shape should be (channel, image_width, image_height). + input_bound (Union[tuple, list]): The pixel range of original images, which should be like [minimum_pixel, + maximum_pixel] or (minimum_pixel, maximum_pixel). + loss_weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to + obtain better results. Default: (1, 0.2, 5). + + Raises: + TypeError: If the type of network is not Cell. + ValueError: If any value of input_shape is not positive int. + ValueError: If any value of loss_weights is not positive value. 
+ """ + def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): + self._network = check_param_type('network', network, Cell) + for sub_loss_weight in loss_weights: + check_value_positive('sub_loss_weight', sub_loss_weight) + self._loss = InversionLoss(self._network, loss_weights) + self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) + for shape_dim in input_shape: + check_int_positive('shape_dim', shape_dim) + self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) + + def generate(self, target_features, iters=100): + """ + Reconstruct images based on target_features. + + Args: + target_features (numpy.ndarray): Deep representations of original images. The first dimension of + target_features should be img_num. It should be noted that the shape of target_features should be + (1, dim2, dim3, ...) if img_num equals 1. + iters (int): iteration times of inversion attack, which should be positive integers. Default: 100. + + Returns: + numpy.ndarray, reconstructed images, which are expected to be similar to original images. + + Raises: + TypeError: If the type of target_features is not numpy.ndarray. + ValueError: If any value of iters is not positive int.Z + + Examples: + >>> net = LeNet5() + >>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), + >>> loss_weights=[1, 0.2, 5]) + >>> features = np.random.random((2, 10)).astype(np.float32) + >>> images = inversion_attack.generate(features, iters=10) + >>> print(images.shape) + (2, 1, 32, 32) + """ + target_features = check_numpy_param('target_features', target_features) + iters = check_int_positive('iters', iters) + + # shape checking + img_num = target_features.shape[0] + test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32) + test_out = self._network(Tensor(test_input)).asnumpy() + if test_out.shape != target_features.shape: + msg = "The shape of target_features ({}) is not in accordance with the shape" \ + " of network output({})".format(target_features.shape, test_out.shape) + raise ValueError(msg) + + loss_net = self._loss + loss_grad = GradWrapWithLoss(loss_net) + + inversion_images = [] + for i in range(img_num): + target_feature_n = target_features[i] + inversion_image_n = np.random.random((1,) + self._input_shape).astype(np.float32)*0.05 + for s in range(iters): + x_grad = loss_grad(Tensor(inversion_image_n), Tensor(target_feature_n)).asnumpy() + x_grad_sign = np.sign(x_grad) + inversion_image_n -= x_grad_sign*0.01 + inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1]) + current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n)) + LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss)) + inversion_images.append(inversion_image_n) + return np.concatenate(np.array(inversion_images)) + + def evaluate(self, original_images, inversion_images): + """ + Compute the average L2 distance and SSIM value between original images and inversion images. + + Args: + original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, + img_height). + inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, + img_height). + + Returns: + tuple, the average l2 distance and average ssim value between original images and inversion images. 
+
+        Examples:
+            >>> net = LeNet5()
+            >>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1),
+            >>>                                         loss_weights=[1, 0.2, 5])
+            >>> features = np.random.random((2, 10)).astype(np.float32)
+            >>> inver_images = inversion_attack.generate(features, iters=10)
+            >>> ori_images = np.random.random((2, 1, 32, 32))
+            >>> result = inversion_attack.evaluate(ori_images, inver_images)
+            >>> print(len(result))
+            2
+        """
+        check_numpy_param('original_images', original_images)
+        check_numpy_param('inversion_images', inversion_images)
+        img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images)
+        if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3):
+            msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \
+                  ' but got {} and {}'.format(img_1.shape, img_2.shape)
+            raise ValueError(msg)
+        total_l2_distance = 0
+        total_ssim = 0
+        img_1 = img_1.transpose(0, 2, 3, 1)
+        img_2 = img_2.transpose(0, 2, 3, 1)
+        for i in range(img_1.shape[0]):
+            _, l2_dis, _ = calculate_lp_distance(img_1[i], img_2[i])
+            total_l2_distance += l2_dis
+            total_ssim += compute_ssim(img_1[i], img_2[i])
+        avg_l2_dis = total_l2_distance / img_1.shape[0]
+        avg_ssim = total_ssim / img_1.shape[0]
+        return avg_l2_dis, avg_ssim
diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py
index cf5f954..d4147c1 100644
--- a/mindarmour/utils/_check_param.py
+++ b/mindarmour/utils/_check_param.py
@@ -61,6 +61,11 @@ def check_param_multi_types(arg_name, arg_value, valid_types):
 
 def check_int_positive(arg_name, arg_value):
     """Check positive integer."""
+    # bool is a subclass of int in Python, so exclude bool values explicitly before the int check.
+    if isinstance(arg_value, bool):
+        msg = '{} should not be bool value, but got {}'.format(arg_name, arg_value)
+        LOGGER.error(TAG, msg)
+        raise ValueError(msg)
     arg_value = check_param_type(arg_name, arg_value, int)
     if arg_value <= 0:
         msg = '{} must be greater than 0, but got {}'.format(arg_name,
diff --git a/mindarmour/utils/util.py b/mindarmour/utils/util.py
index 118e620..859af10 100644
--- a/mindarmour/utils/util.py
+++ b/mindarmour/utils/util.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 """ Util for MindArmour. """
 import numpy as np
+from scipy.ndimage.filters import convolve
+
 from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore.ops.composite import GradOperation
 
-from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types
+from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types, check_equal_shape
 from .logger import LogUtil
 
 
@@ -61,7 +63,7 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes)
     Args:
         grad_wrap_net (Cell): A network wrapped by GradWrap.
         inputs (numpy.ndarray): Input samples.
-        num_boxes (int): Number of boxes infered by each image.
+        num_boxes (int): Number of boxes inferred by each image.
         num_classes (int): Number of labels of model output.
 
     Returns:
@@ -251,3 +253,109 @@ def to_tensor_tuple(inputs_ori):
     else:
         inputs_tensor = (Tensor(inputs_ori),)
     return inputs_tensor
+
+
+def calculate_lp_distance(original_image, compared_image):
+    """
+    Calculate the l0, l2 and linf distances between two images with the same shape.
+
+    Args:
+        original_image (numpy.ndarray): Original image.
+        compared_image (numpy.ndarray): Another image for comparison.
+
+    Returns:
+        tuple, the (l0, l2, linf) distances between the two images.
+
+    Raises:
+        TypeError: If the type of original_image or compared_image is not numpy.ndarray.
+        ValueError: If the shapes of original_image and compared_image are not the same.
+    """
+    check_numpy_param('original_image', original_image)
+    check_numpy_param('compared_image', compared_image)
+    check_equal_shape('original_image', original_image, 'compared_image', compared_image)
+    avoid_zero_div = 1e-14
+    diff = (original_image - compared_image).flatten()
+    data = original_image.flatten()
+    l0_dist = np.linalg.norm(diff, ord=0) \
+        / (np.linalg.norm(data, ord=0) + avoid_zero_div)
+    l2_dist = np.linalg.norm(diff, ord=2) \
+        / (np.linalg.norm(data, ord=2) + avoid_zero_div)
+    linf_dist = np.linalg.norm(diff, ord=np.inf) \
+        / (np.linalg.norm(data, ord=np.inf) + avoid_zero_div)
+    return l0_dist, l2_dist, linf_dist
+
+
+def compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11):
+    """
+    Compute the structural similarity between two images.
+
+    Args:
+        img_1 (numpy.ndarray): The first image to be compared. The shape of img_1 should be (img_width, img_height,
+            channels).
+        img_2 (numpy.ndarray): The second image to be compared. The shape of img_2 should be (img_width, img_height,
+            channels).
+        kernel_sigma (float): Gaussian kernel param. Default: 1.5.
+        kernel_width (int): Another Gaussian kernel param. Default: 11.
+
+    Returns:
+        float, structural similarity.
+    """
+    img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2)
+    if len(img_1.shape) > 2:
+        if (len(img_1.shape) != 3) or (img_1.shape[2] != 1 and img_1.shape[2] != 3):
+            msg = 'The shape format of img_1 and img_2 should be (img_width, img_height, channels),' \
+                  ' but got {} and {}'.format(img_1.shape, img_2.shape)
+            raise ValueError(msg)
+
+    if len(img_1.shape) > 2:
+        total_ssim = 0
+        for i in range(img_1.shape[2]):
+            total_ssim += compute_ssim(img_1[:, :, i], img_2[:, :, i])
+        return total_ssim / img_1.shape[2]
+
+    # Create gaussian kernel
+    gaussian_kernel = np.zeros((kernel_width, kernel_width))
+    for i in range(kernel_width):
+        for j in range(kernel_width):
+            gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp(
+                - (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2)))
+
+    img_1 = img_1.astype(np.float32)
+    img_2 = img_2.astype(np.float32)
+
+    img_sq_1 = img_1**2
+    img_sq_2 = img_2**2
+    img_12 = img_1*img_2
+
+    # Mean
+    img_mu_1 = convolve(img_1, gaussian_kernel)
+    img_mu_2 = convolve(img_2, gaussian_kernel)
+
+    # Mean square
+    img_mu_sq_1 = img_mu_1**2
+    img_mu_sq_2 = img_mu_2**2
+    img_mu_12 = img_mu_1*img_mu_2
+
+    # Variances
+    img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel)
+    img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel)
+
+    # Covariance
+    img_sigma_12 = convolve(img_12, gaussian_kernel)
+
+    # Centered squares of variances
+    img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1
+    img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2
+    img_sigma_12 = img_sigma_12 - img_mu_12
+
+    k_1 = 0.01
+    k_2 = 0.03
+    c_1 = (k_1*255)**2
+    c_2 = (k_2*255)**2
+
+    # Calculate ssim
+    num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2)
+    den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1
+                                                  + img_sigma_sq_2 + c_2)
+    res = np.average(num_ssim / den_ssim)
+    return res
diff --git a/tests/ut/python/privacy/evaluation/test_inversion_attack.py b/tests/ut/python/privacy/evaluation/test_inversion_attack.py
new file mode 100644
index 0000000..143680e
--- /dev/null
+++ b/tests/ut/python/privacy/evaluation/test_inversion_attack.py
@@ -0,0 +1,41 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Inversion attack test
+"""
+import pytest
+
+import numpy as np
+
+import mindspore.context as context
+
+from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
+
+from ut.python.utils.mock_net import Net
+
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_inversion_attack():
+    net = Net()
+    target_features = np.random.random((2, 10)).astype(np.float32)
+    inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
+    inversion_images = inversion_attack.generate(target_features, iters=10)
+    assert target_features.shape[0] == inversion_images.shape[0]
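
For readers skimming the patch, the sketch below is not part of the change set; it is only a framework-free illustration of the three-part objective that `InversionLoss.construct` above implements (feature matching, pixel-continuity smoothing, and L2 regularization). The helper name `inversion_loss`, the `forward_fn` argument, and the `alpha/beta/gamma` names are illustrative assumptions; the patch itself builds the loss from MindSpore's `MSELoss` and `P.Split`, and `ImageInversionAttack.generate` minimizes it with sign-gradient descent on the candidate image.

```python
import numpy as np

def inversion_loss(x, target_features, forward_fn, weights=(1.0, 0.2, 5.0)):
    """x: candidate images (N, C, H, W); forward_fn maps images to deep features."""
    alpha, beta, gamma = weights
    # 1) feature-matching term, normalized by the magnitude of the target features
    feat = forward_fn(x)
    loss_feat = np.mean((feat - target_features) ** 2) / np.mean(target_features ** 2)
    # 2) smoothness term penalizing differences between adjacent pixels (rows and columns)
    loss_smooth = np.mean((x[:, :, 1:, :] - x[:, :, :-1, :]) ** 2) \
                  + np.mean((x[:, :, :, 1:] - x[:, :, :, :-1]) ** 2)
    # 3) L2 regularization keeping pixel values small
    loss_reg = np.mean(x ** 2)
    return alpha * loss_feat + beta * loss_smooth + gamma * loss_reg
```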