@@ -42,7 +42,7 @@ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||
python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||
``` | |||
## 4. suppress privacy training | |||
## 4. Suppress privacy training | |||
With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | |||
layers) are set to zero as the training step grows, which can | |||
@@ -52,3 +52,15 @@ With suppress privacy mechanism, the values of some trainable parameters (such | |||
cd examples/privacy/sup_privacy | |||
python sup_privacy.py | |||
``` | |||
## 5. Image inversion attack | |||
Inversion attack means reconstructing an image based on its deep representations. For example, | |||
reconstruct a MNIST image based on its output through LeNet5. The mechanism behind it is that well-trained | |||
model can "remember" those training dataset. Therefore, inversion attack can be used to estimate the privacy | |||
leakage of training tasks. | |||
```sh | |||
cd examples/privacy/inversion_attack | |||
python mnist_inversion_attack.py | |||
``` |
@@ -0,0 +1,108 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Examples of image inversion attack | |||
""" | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore import Tensor, context | |||
from mindspore import nn | |||
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||
from mindarmour.utils.logger import LogUtil | |||
from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize | |||
from examples.common.dataset.data_processing import generate_mnist_dataset | |||
LOGGER = LogUtil.get_instance() | |||
LOGGER.set_level('INFO') | |||
TAG = 'InversionAttack' | |||
# pylint: disable=invalid-name | |||
class LeNet5_part(nn.Cell): | |||
""" | |||
Part of LeNet5 network. | |||
""" | |||
def __init__(self): | |||
super(LeNet5_part, self).__init__() | |||
self.conv1 = conv(1, 6, 5) | |||
self.conv2 = conv(6, 16, 5) | |||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||
self.fc2 = fc_with_initialize(120, 84) | |||
self.fc3 = fc_with_initialize(84, 10) | |||
self.relu = nn.ReLU() | |||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
self.flatten = nn.Flatten() | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
return x | |||
def mnist_inversion_attack(net): | |||
""" | |||
Image inversion attack based on LeNet5 and MNIST dataset. | |||
""" | |||
# upload trained network | |||
ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
load_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, load_dict) | |||
# get test data | |||
data_list = "../../common/dataset/MNIST/test" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size) | |||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
i = 0 | |||
batch_num = 1 | |||
sample_num = 10 | |||
for data in ds.create_tuple_iterator(output_numpy=True): | |||
i += 1 | |||
images = data[0].astype(np.float32) | |||
target_features = net(Tensor(images)).asnumpy() | |||
original_images = images[: sample_num] | |||
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100) | |||
for n in range(1, sample_num+1): | |||
plt.subplot(2, sample_num, n) | |||
plt.gray() | |||
plt.imshow(images[n - 1].reshape(32, 32)) | |||
plt.subplot(2, sample_num, n + sample_num) | |||
plt.gray() | |||
plt.imshow(inversion_images[n - 1].reshape(32, 32)) | |||
plt.show() | |||
if i >= batch_num: | |||
break | |||
# evaluate the similarity between inversion images and original images | |||
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images) | |||
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis)) | |||
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim)) | |||
if __name__ == '__main__': | |||
# device_target can be "CPU", "GPU" or "Ascend" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
# attack based on complete LeNet5 | |||
mnist_inversion_attack(LeNet5()) | |||
# attack based on part of LeNet5. The network is more shallower and can lead to a better attack result | |||
mnist_inversion_attack(LeNet5_part()) |
@@ -9,6 +9,10 @@ from .adv_robustness.detectors.detector import Detector | |||
from .fuzz_testing.fuzzing import Fuzzer | |||
from .privacy.diff_privacy import DPModel | |||
from .privacy.evaluation.membership_inference import MembershipInference | |||
from .privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||
from .privacy.sup_privacy.train.model import SuppressModel | |||
from .privacy.sup_privacy.mask_monitor.masker import SuppressMasker | |||
from .privacy.evaluation.inversion_attack import ImageInversionAttack | |||
__all__ = ['Attack', | |||
'BlackModel', | |||
@@ -16,4 +20,8 @@ __all__ = ['Attack', | |||
'Defense', | |||
'Fuzzer', | |||
'DPModel', | |||
'MembershipInference'] | |||
'MembershipInference', | |||
'SuppressModel', | |||
'SuppressCtrl', | |||
'SuppressMasker', | |||
'ImageInversionAttack'] |
@@ -17,84 +17,15 @@ Attack evaluation. | |||
import numpy as np | |||
from scipy.ndimage.filters import convolve | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils._check_param import check_pair_numpy_param, \ | |||
check_param_type, check_numpy_param, check_equal_shape | |||
check_param_type, check_numpy_param | |||
from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'AttackEvaluate' | |||
def _compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||
""" | |||
compute structural similarity. | |||
Args: | |||
img_1 (numpy.ndarray): The first image to be compared. | |||
img_2 (numpy.ndarray): The second image to be compared. | |||
kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||
kernel_width (int): Another Gassian kernel param. Default: 11. | |||
Returns: | |||
float, structural similarity. | |||
""" | |||
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||
if len(img_1.shape) > 2: | |||
total_ssim = 0 | |||
for i in range(img_1.shape[2]): | |||
total_ssim += _compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||
return total_ssim / 3 | |||
# Create gaussian kernel | |||
gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||
for i in range(kernel_width): | |||
for j in range(kernel_width): | |||
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||
img_1 = img_1.astype(np.float32) | |||
img_2 = img_2.astype(np.float32) | |||
img_sq_1 = img_1**2 | |||
img_sq_2 = img_2**2 | |||
img_12 = img_1*img_2 | |||
# Mean | |||
img_mu_1 = convolve(img_1, gaussian_kernel) | |||
img_mu_2 = convolve(img_2, gaussian_kernel) | |||
# Mean square | |||
img_mu_sq_1 = img_mu_1**2 | |||
img_mu_sq_2 = img_mu_2**2 | |||
img_mu_12 = img_mu_1*img_mu_2 | |||
# Variances | |||
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||
# Covariance | |||
img_sigma_12 = convolve(img_12, gaussian_kernel) | |||
# Centered squares of variances | |||
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||
img_sigma_12 = img_sigma_12 - img_mu_12 | |||
k_1 = 0.01 | |||
k_2 = 0.03 | |||
c_1 = (k_1*255)**2 | |||
c_2 = (k_2*255)**2 | |||
# Calculate ssim | |||
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||
+ img_sigma_sq_2 + c_2) | |||
res = np.average(num_ssim / den_ssim) | |||
return res | |||
class AttackEvaluate: | |||
""" | |||
Evaluation metrics of attack methods. | |||
@@ -217,16 +148,11 @@ class AttackEvaluate: | |||
l0_dist = 0 | |||
l2_dist = 0 | |||
linf_dist = 0 | |||
avoid_zero_div = 1e-14 | |||
for i in idxes: | |||
diff = (self._adv_inputs[i] - self._inputs[i]).flatten() | |||
data = self._inputs[i].flatten() | |||
l0_dist += np.linalg.norm(diff, ord=0) \ | |||
/ (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||
l2_dist += np.linalg.norm(diff, ord=2) \ | |||
/ (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||
linf_dist += np.linalg.norm(diff, ord=np.inf) \ | |||
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||
l0_dist_i, l2_dist_i, linf_dist_i = calculate_lp_distance(self._inputs[i], self._adv_inputs[i]) | |||
l0_dist += l0_dist_i | |||
l2_dist += l2_dist_i | |||
linf_dist += linf_dist_i | |||
return l0_dist / success_num, l2_dist / success_num, \ | |||
linf_dist / success_num | |||
@@ -249,7 +175,7 @@ class AttackEvaluate: | |||
total_ssim = 0.0 | |||
for _, i in enumerate(self._success_idxes): | |||
total_ssim += _compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||
total_ssim += compute_ssim(self._adv_inputs[i], self._inputs[i]) | |||
return total_ssim / success_num | |||
@@ -0,0 +1,210 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Inversion Attack | |||
""" | |||
import numpy as np | |||
from mindspore.nn import Cell, MSELoss | |||
from mindspore import Tensor | |||
from mindspore.ops import operations as P | |||
from mindarmour.utils.util import GradWrapWithLoss | |||
from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \ | |||
check_int_positive, check_numpy_param, check_value_positive, check_equal_shape | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils.util import calculate_lp_distance, compute_ssim | |||
LOGGER = LogUtil.get_instance() | |||
LOGGER.set_level('INFO') | |||
TAG = 'Image inversion attack' | |||
class InversionLoss(Cell): | |||
""" | |||
The loss function for inversion attack. | |||
Args: | |||
network (Cell): The network used to infer images' deep representations. | |||
weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||
obtain better results. | |||
""" | |||
def __init__(self, network, weights): | |||
super(InversionLoss, self).__init__() | |||
self._network = check_param_type('network', network, Cell) | |||
self._mse_loss = MSELoss() | |||
self._weights = check_param_multi_types('weights', weights, [list, tuple]) | |||
self._get_shape = P.Shape() | |||
def construct(self, input_data, target_features): | |||
""" | |||
Calculate the inversion attack loss, which consists of three parts. Loss_1 is for evaluating the difference | |||
between the target deep representations and current representations; Loss_2 is for evaluating the continuity | |||
between adjacent pixels; Loss_3 is for regularization. | |||
Args: | |||
input_data (Tensor): The reconstructed image during inversion attack. | |||
target_features (Tensor): Deep representations of the original image. | |||
Returns: | |||
Tensor, inversion attack loss of the current iteration. | |||
""" | |||
output = self._network(input_data) | |||
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) | |||
data_shape = self._get_shape(input_data) | |||
split_op_1 = P.Split(2, data_shape[2]) | |||
split_op_2 = P.Split(3, data_shape[3]) | |||
data_split_1 = split_op_1(input_data) | |||
data_split_2 = split_op_2(input_data) | |||
loss_2 = 0 | |||
for i in range(1, data_shape[2]): | |||
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) | |||
for j in range(1, data_shape[3]): | |||
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) | |||
loss_3 = self._mse_loss(input_data, 0) | |||
loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2] | |||
return loss | |||
class ImageInversionAttack: | |||
""" | |||
An attack method used to reconstruct images by inverting their deep representations. | |||
References: `Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them. | |||
2014. <https://arxiv.org/pdf/1412.0035.pdf>`_ | |||
Args: | |||
network (Cell): The network used to infer images' deep representations. | |||
input_shape (tuple): Data shape of single network input, which should be in accordance with the given | |||
network. The format of shape should be (channel, image_width, image_height). | |||
input_bound (Union[tuple, list]): The pixel range of original images, which should be like [minimum_pixel, | |||
maximum_pixel] or (minimum_pixel, maximum_pixel). | |||
loss_weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to | |||
obtain better results. Default: (1, 0.2, 5). | |||
Raises: | |||
TypeError: If the type of network is not Cell. | |||
ValueError: If any value of input_shape is not positive int. | |||
ValueError: If any value of loss_weights is not positive value. | |||
""" | |||
def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): | |||
self._network = check_param_type('network', network, Cell) | |||
for sub_loss_weight in loss_weights: | |||
check_value_positive('sub_loss_weight', sub_loss_weight) | |||
self._loss = InversionLoss(self._network, loss_weights) | |||
self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) | |||
for shape_dim in input_shape: | |||
check_int_positive('shape_dim', shape_dim) | |||
self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) | |||
def generate(self, target_features, iters=100): | |||
""" | |||
Reconstruct images based on target_features. | |||
Args: | |||
target_features (numpy.ndarray): Deep representations of original images. The first dimension of | |||
target_features should be img_num. It should be noted that the shape of target_features should be | |||
(1, dim2, dim3, ...) if img_num equals 1. | |||
iters (int): iteration times of inversion attack, which should be positive integers. Default: 100. | |||
Returns: | |||
numpy.ndarray, reconstructed images, which are expected to be similar to original images. | |||
Raises: | |||
TypeError: If the type of target_features is not numpy.ndarray. | |||
ValueError: If any value of iters is not positive int.Z | |||
Examples: | |||
>>> net = LeNet5() | |||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||
>>> loss_weights=[1, 0.2, 5]) | |||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||
>>> images = inversion_attack.generate(features, iters=10) | |||
>>> print(images.shape) | |||
(2, 1, 32, 32) | |||
""" | |||
target_features = check_numpy_param('target_features', target_features) | |||
iters = check_int_positive('iters', iters) | |||
# shape checking | |||
img_num = target_features.shape[0] | |||
test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32) | |||
test_out = self._network(Tensor(test_input)).asnumpy() | |||
if test_out.shape != target_features.shape: | |||
msg = "The shape of target_features ({}) is not in accordance with the shape" \ | |||
" of network output({})".format(target_features.shape, test_out.shape) | |||
raise ValueError(msg) | |||
loss_net = self._loss | |||
loss_grad = GradWrapWithLoss(loss_net) | |||
inversion_images = [] | |||
for i in range(img_num): | |||
target_feature_n = target_features[i] | |||
inversion_image_n = np.random.random((1,) + self._input_shape).astype(np.float32)*0.05 | |||
for s in range(iters): | |||
x_grad = loss_grad(Tensor(inversion_image_n), Tensor(target_feature_n)).asnumpy() | |||
x_grad_sign = np.sign(x_grad) | |||
inversion_image_n -= x_grad_sign*0.01 | |||
inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1]) | |||
current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n)) | |||
LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss)) | |||
inversion_images.append(inversion_image_n) | |||
return np.concatenate(np.array(inversion_images)) | |||
def evaluate(self, original_images, inversion_images): | |||
""" | |||
Compute the average L2 distance and SSIM value between original images and inversion images. | |||
Args: | |||
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width, | |||
img_height). | |||
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width, | |||
img_height). | |||
Returns: | |||
tuple, the average l2 distance and average ssim value between original images and inversion images. | |||
Examples: | |||
>>> net = LeNet5() | |||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||
>>> loss_weights=[1, 0.2, 5]) | |||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||
>>> inver_images = inversion_attack.generate(features, iters=10) | |||
>>> ori_images = np.random.random((2, 1, 32, 32)) | |||
>>> result = inversion_attack.evaluate(ori_images, inver_images) | |||
>>> print(len(result)) | |||
2 | |||
""" | |||
check_numpy_param('original_images', original_images) | |||
check_numpy_param('inversion_images', inversion_images) | |||
img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images) | |||
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3): | |||
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \ | |||
' but got {} and {}'.format(img_1.shape, img_2.shape) | |||
raise ValueError(msg) | |||
total_l2_distance = 0 | |||
total_ssim = 0 | |||
img_1 = img_1.transpose(0, 2, 3, 1) | |||
img_2 = img_2.transpose(0, 2, 3, 1) | |||
for i in range(img_1.shape[0]): | |||
_, l2_dis, _ = calculate_lp_distance(img_1[i], img_2[i]) | |||
total_l2_distance += l2_dis | |||
total_ssim += compute_ssim(img_1[i], img_2[i]) | |||
avg_l2_dis = total_l2_distance / img_1.shape[0] | |||
avg_ssim = total_ssim / img_1.shape[0] | |||
return avg_l2_dis, avg_ssim |
@@ -61,6 +61,11 @@ def check_param_multi_types(arg_name, arg_value, valid_types): | |||
def check_int_positive(arg_name, arg_value): | |||
"""Check positive integer.""" | |||
# 'True' is treated as int(1) in python, which is a bug. | |||
if isinstance(arg_value, bool): | |||
msg = '{} should not be bool value, but got {}'.format(arg_name, arg_value) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
arg_value = check_param_type(arg_name, arg_value, int) | |||
if arg_value <= 0: | |||
msg = '{} must be greater than 0, but got {}'.format(arg_name, | |||
@@ -13,11 +13,13 @@ | |||
# limitations under the License. | |||
""" Util for MindArmour. """ | |||
import numpy as np | |||
from scipy.ndimage.filters import convolve | |||
from mindspore import Tensor | |||
from mindspore.nn import Cell | |||
from mindspore.ops.composite import GradOperation | |||
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types | |||
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types, check_equal_shape | |||
from .logger import LogUtil | |||
@@ -61,7 +63,7 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes) | |||
Args: | |||
grad_wrap_net (Cell): A network wrapped by GradWrap. | |||
inputs (numpy.ndarray): Input samples. | |||
num_boxes (int): Number of boxes infered by each image. | |||
num_boxes (int): Number of boxes inferred by each image. | |||
num_classes (int): Number of labels of model output. | |||
Returns: | |||
@@ -251,3 +253,109 @@ def to_tensor_tuple(inputs_ori): | |||
else: | |||
inputs_tensor = (Tensor(inputs_ori),) | |||
return inputs_tensor | |||
def calculate_lp_distance(original_image, compared_image): | |||
""" | |||
Calculate l0, l2 and linf distance for two images with the same shape. | |||
Args: | |||
original_image (np.ndarray): Original image. | |||
compared_image (np.ndarray): Another image for comparison. | |||
Returns: | |||
tuple, (l0, l2 and linf) distances between two images. | |||
Raises: | |||
TypeError: If type of original_image or type of compared_image is not numpy.ndarray. | |||
ValueError: If the shape of original_image and compared_image are not the same. | |||
""" | |||
check_numpy_param('original_image', original_image) | |||
check_numpy_param('compared_image', compared_image) | |||
check_equal_shape('original_image', original_image, 'compared_image', compared_image) | |||
avoid_zero_div = 1e-14 | |||
diff = (original_image - compared_image).flatten() | |||
data = original_image.flatten() | |||
l0_dist = np.linalg.norm(diff, ord=0) \ | |||
/ (np.linalg.norm(data, ord=0) + avoid_zero_div) | |||
l2_dist = np.linalg.norm(diff, ord=2) \ | |||
/ (np.linalg.norm(data, ord=2) + avoid_zero_div) | |||
linf_dist = np.linalg.norm(diff, ord=np.inf) \ | |||
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div) | |||
return l0_dist, l2_dist, linf_dist | |||
def compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11): | |||
""" | |||
compute structural similarity between two images. | |||
Args: | |||
img_1 (numpy.ndarray): The first image to be compared. The shape of img_1 should be (img_width, img_height, | |||
channels). | |||
img_2 (numpy.ndarray): The second image to be compared. The shape of img_2 should be (img_width, img_height, | |||
channels). | |||
kernel_sigma (float): Gassian kernel param. Default: 1.5. | |||
kernel_width (int): Another Gassian kernel param. Default: 11. | |||
Returns: | |||
float, structural similarity. | |||
""" | |||
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2) | |||
if len(img_1.shape) > 2: | |||
if (len(img_1.shape) != 3) or (img_1.shape[2] != 1 and img_1.shape[2] != 3): | |||
msg = 'The shape format of img_1 and img_2 should be (img_width, img_height, channels),' \ | |||
' but got {} and {}'.format(img_1.shape, img_2.shape) | |||
raise ValueError(msg) | |||
if len(img_1.shape) > 2: | |||
total_ssim = 0 | |||
for i in range(img_1.shape[2]): | |||
total_ssim += compute_ssim(img_1[:, :, i], img_2[:, :, i]) | |||
return total_ssim / 3 | |||
# Create gaussian kernel | |||
gaussian_kernel = np.zeros((kernel_width, kernel_width)) | |||
for i in range(kernel_width): | |||
for j in range(kernel_width): | |||
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp( | |||
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2))) | |||
img_1 = img_1.astype(np.float32) | |||
img_2 = img_2.astype(np.float32) | |||
img_sq_1 = img_1**2 | |||
img_sq_2 = img_2**2 | |||
img_12 = img_1*img_2 | |||
# Mean | |||
img_mu_1 = convolve(img_1, gaussian_kernel) | |||
img_mu_2 = convolve(img_2, gaussian_kernel) | |||
# Mean square | |||
img_mu_sq_1 = img_mu_1**2 | |||
img_mu_sq_2 = img_mu_2**2 | |||
img_mu_12 = img_mu_1*img_mu_2 | |||
# Variances | |||
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel) | |||
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel) | |||
# Covariance | |||
img_sigma_12 = convolve(img_12, gaussian_kernel) | |||
# Centered squares of variances | |||
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1 | |||
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2 | |||
img_sigma_12 = img_sigma_12 - img_mu_12 | |||
k_1 = 0.01 | |||
k_2 = 0.03 | |||
c_1 = (k_1*255)**2 | |||
c_2 = (k_2*255)**2 | |||
# Calculate ssim | |||
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2) | |||
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1 | |||
+ img_sigma_sq_2 + c_2) | |||
res = np.average(num_ssim / den_ssim) | |||
return res |
@@ -0,0 +1,41 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Inversion attack test | |||
""" | |||
import pytest | |||
import numpy as np | |||
import mindspore.context as context | |||
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||
from ut.python.utils.mock_net import Net | |||
context.set_context(mode=context.GRAPH_MODE) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_inversion_attack(): | |||
net = Net() | |||
target_features = np.random.random((2, 10)).astype(np.float32) | |||
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) | |||
inversion_images = inversion_attack.generate(target_features, iters=10) | |||
assert target_features.shape[0] == inversion_images.shape[0] |