Browse Source

Add image inversion attack method

tags/v1.2.1
jin-xiulang 4 years ago
parent
commit
a05744f8c9
9 changed files with 503 additions and 85 deletions
  1. +13
    -1
      examples/privacy/README.md
  2. +0
    -0
      examples/privacy/inversion_attack/__init__.py
  3. +108
    -0
      examples/privacy/inversion_attack/mnist_inversion_attack.py
  4. +9
    -1
      mindarmour/__init__.py
  5. +7
    -81
      mindarmour/adv_robustness/evaluations/attack_evaluation.py
  6. +210
    -0
      mindarmour/privacy/evaluation/inversion_attack.py
  7. +5
    -0
      mindarmour/utils/_check_param.py
  8. +110
    -2
      mindarmour/utils/util.py
  9. +41
    -0
      tests/ut/python/privacy/evaluation/test_inversion_attack.py

+ 13
- 1
examples/privacy/README.md View File

@@ -42,7 +42,7 @@ python train.py --data_path home_path_to_cifar100 --ckpt_path ./
python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt
```

## 4. suppress privacy training
## 4. Suppress privacy training

With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected
layers) are set to zero as the training step grows, which can
@@ -52,3 +52,15 @@ With suppress privacy mechanism, the values of some trainable parameters (such
cd examples/privacy/sup_privacy
python sup_privacy.py
```

## 5. Image inversion attack

Inversion attack means reconstructing an image based on its deep representations. For example,
reconstruct a MNIST image based on its output through LeNet5. The mechanism behind it is that well-trained
model can "remember" those training dataset. Therefore, inversion attack can be used to estimate the privacy
leakage of training tasks.

```sh
cd examples/privacy/inversion_attack
python mnist_inversion_attack.py
```

+ 0
- 0
examples/privacy/inversion_attack/__init__.py View File


+ 108
- 0
examples/privacy/inversion_attack/mnist_inversion_attack.py View File

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Examples of image inversion attack
"""
import numpy as np
import matplotlib.pyplot as plt

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor, context
from mindspore import nn
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
from mindarmour.utils.logger import LogUtil

from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize
from examples.common.dataset.data_processing import generate_mnist_dataset

LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
TAG = 'InversionAttack'


# pylint: disable=invalid-name
class LeNet5_part(nn.Cell):
"""
Part of LeNet5 network.
"""
def __init__(self):
super(LeNet5_part, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()

def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
return x


def mnist_inversion_attack(net):
"""
Image inversion attack based on LeNet5 and MNIST dataset.
"""
# upload trained network
ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)

# get test data
data_list = "../../common/dataset/MNIST/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size)

inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])

i = 0
batch_num = 1
sample_num = 10
for data in ds.create_tuple_iterator(output_numpy=True):
i += 1
images = data[0].astype(np.float32)
target_features = net(Tensor(images)).asnumpy()
original_images = images[: sample_num]
inversion_images = inversion_attack.generate(target_features[:sample_num], iters=100)
for n in range(1, sample_num+1):
plt.subplot(2, sample_num, n)
plt.gray()
plt.imshow(images[n - 1].reshape(32, 32))
plt.subplot(2, sample_num, n + sample_num)
plt.gray()
plt.imshow(inversion_images[n - 1].reshape(32, 32))
plt.show()
if i >= batch_num:
break
# evaluate the similarity between inversion images and original images
avg_l2_dis, avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
LOGGER.info(TAG, 'The average L2 distance between original images and inversion images is: {}'.format(avg_l2_dis))
LOGGER.info(TAG, 'The average ssim value between original images and inversion images is: {}'.format(avg_ssim))


if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# attack based on complete LeNet5
mnist_inversion_attack(LeNet5())
# attack based on part of LeNet5. The network is more shallower and can lead to a better attack result
mnist_inversion_attack(LeNet5_part())

+ 9
- 1
mindarmour/__init__.py View File

@@ -9,6 +9,10 @@ from .adv_robustness.detectors.detector import Detector
from .fuzz_testing.fuzzing import Fuzzer
from .privacy.diff_privacy import DPModel
from .privacy.evaluation.membership_inference import MembershipInference
from .privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl
from .privacy.sup_privacy.train.model import SuppressModel
from .privacy.sup_privacy.mask_monitor.masker import SuppressMasker
from .privacy.evaluation.inversion_attack import ImageInversionAttack

__all__ = ['Attack',
'BlackModel',
@@ -16,4 +20,8 @@ __all__ = ['Attack',
'Defense',
'Fuzzer',
'DPModel',
'MembershipInference']
'MembershipInference',
'SuppressModel',
'SuppressCtrl',
'SuppressMasker',
'ImageInversionAttack']

+ 7
- 81
mindarmour/adv_robustness/evaluations/attack_evaluation.py View File

@@ -17,84 +17,15 @@ Attack evaluation.

import numpy as np

from scipy.ndimage.filters import convolve

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, \
check_param_type, check_numpy_param, check_equal_shape
check_param_type, check_numpy_param
from mindarmour.utils.util import calculate_lp_distance, compute_ssim

LOGGER = LogUtil.get_instance()
TAG = 'AttackEvaluate'


def _compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11):
"""
compute structural similarity.
Args:
img_1 (numpy.ndarray): The first image to be compared.
img_2 (numpy.ndarray): The second image to be compared.
kernel_sigma (float): Gassian kernel param. Default: 1.5.
kernel_width (int): Another Gassian kernel param. Default: 11.

Returns:
float, structural similarity.
"""
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2)

if len(img_1.shape) > 2:
total_ssim = 0
for i in range(img_1.shape[2]):
total_ssim += _compute_ssim(img_1[:, :, i], img_2[:, :, i])
return total_ssim / 3

# Create gaussian kernel
gaussian_kernel = np.zeros((kernel_width, kernel_width))
for i in range(kernel_width):
for j in range(kernel_width):
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp(
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2)))

img_1 = img_1.astype(np.float32)
img_2 = img_2.astype(np.float32)

img_sq_1 = img_1**2
img_sq_2 = img_2**2
img_12 = img_1*img_2

# Mean
img_mu_1 = convolve(img_1, gaussian_kernel)
img_mu_2 = convolve(img_2, gaussian_kernel)

# Mean square
img_mu_sq_1 = img_mu_1**2
img_mu_sq_2 = img_mu_2**2
img_mu_12 = img_mu_1*img_mu_2

# Variances
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel)
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel)

# Covariance
img_sigma_12 = convolve(img_12, gaussian_kernel)

# Centered squares of variances
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2
img_sigma_12 = img_sigma_12 - img_mu_12

k_1 = 0.01
k_2 = 0.03
c_1 = (k_1*255)**2
c_2 = (k_2*255)**2

# Calculate ssim
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2)
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1
+ img_sigma_sq_2 + c_2)
res = np.average(num_ssim / den_ssim)
return res


class AttackEvaluate:
"""
Evaluation metrics of attack methods.
@@ -217,16 +148,11 @@ class AttackEvaluate:
l0_dist = 0
l2_dist = 0
linf_dist = 0
avoid_zero_div = 1e-14
for i in idxes:
diff = (self._adv_inputs[i] - self._inputs[i]).flatten()
data = self._inputs[i].flatten()
l0_dist += np.linalg.norm(diff, ord=0) \
/ (np.linalg.norm(data, ord=0) + avoid_zero_div)
l2_dist += np.linalg.norm(diff, ord=2) \
/ (np.linalg.norm(data, ord=2) + avoid_zero_div)
linf_dist += np.linalg.norm(diff, ord=np.inf) \
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div)
l0_dist_i, l2_dist_i, linf_dist_i = calculate_lp_distance(self._inputs[i], self._adv_inputs[i])
l0_dist += l0_dist_i
l2_dist += l2_dist_i
linf_dist += linf_dist_i

return l0_dist / success_num, l2_dist / success_num, \
linf_dist / success_num
@@ -249,7 +175,7 @@ class AttackEvaluate:

total_ssim = 0.0
for _, i in enumerate(self._success_idxes):
total_ssim += _compute_ssim(self._adv_inputs[i], self._inputs[i])
total_ssim += compute_ssim(self._adv_inputs[i], self._inputs[i])

return total_ssim / success_num



+ 210
- 0
mindarmour/privacy/evaluation/inversion_attack.py View File

@@ -0,0 +1,210 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inversion Attack
"""
import numpy as np

from mindspore.nn import Cell, MSELoss
from mindspore import Tensor
from mindspore.ops import operations as P

from mindarmour.utils.util import GradWrapWithLoss
from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \
check_int_positive, check_numpy_param, check_value_positive, check_equal_shape
from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import calculate_lp_distance, compute_ssim

LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
TAG = 'Image inversion attack'


class InversionLoss(Cell):
"""
The loss function for inversion attack.

Args:
network (Cell): The network used to infer images' deep representations.
weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to
obtain better results.
"""
def __init__(self, network, weights):
super(InversionLoss, self).__init__()
self._network = check_param_type('network', network, Cell)
self._mse_loss = MSELoss()
self._weights = check_param_multi_types('weights', weights, [list, tuple])
self._get_shape = P.Shape()

def construct(self, input_data, target_features):
"""
Calculate the inversion attack loss, which consists of three parts. Loss_1 is for evaluating the difference
between the target deep representations and current representations; Loss_2 is for evaluating the continuity
between adjacent pixels; Loss_3 is for regularization.

Args:
input_data (Tensor): The reconstructed image during inversion attack.
target_features (Tensor): Deep representations of the original image.

Returns:
Tensor, inversion attack loss of the current iteration.
"""
output = self._network(input_data)
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0)

data_shape = self._get_shape(input_data)
split_op_1 = P.Split(2, data_shape[2])
split_op_2 = P.Split(3, data_shape[3])
data_split_1 = split_op_1(input_data)
data_split_2 = split_op_2(input_data)
loss_2 = 0
for i in range(1, data_shape[2]):
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1])
for j in range(1, data_shape[3]):
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1])

loss_3 = self._mse_loss(input_data, 0)

loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2]
return loss


class ImageInversionAttack:
"""
An attack method used to reconstruct images by inverting their deep representations.

References: `Aravindh Mahendran, Andrea Vedaldi. Understanding Deep Image Representations by Inverting Them.
2014. <https://arxiv.org/pdf/1412.0035.pdf>`_

Args:
network (Cell): The network used to infer images' deep representations.
input_shape (tuple): Data shape of single network input, which should be in accordance with the given
network. The format of shape should be (channel, image_width, image_height).
input_bound (Union[tuple, list]): The pixel range of original images, which should be like [minimum_pixel,
maximum_pixel] or (minimum_pixel, maximum_pixel).
loss_weights (Union[list, tuple]): Weights of three sub-loss in InversionLoss, which can be adjusted to
obtain better results. Default: (1, 0.2, 5).

Raises:
TypeError: If the type of network is not Cell.
ValueError: If any value of input_shape is not positive int.
ValueError: If any value of loss_weights is not positive value.
"""
def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)):
self._network = check_param_type('network', network, Cell)
for sub_loss_weight in loss_weights:
check_value_positive('sub_loss_weight', sub_loss_weight)
self._loss = InversionLoss(self._network, loss_weights)
self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple])
for shape_dim in input_shape:
check_int_positive('shape_dim', shape_dim)
self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple])

def generate(self, target_features, iters=100):
"""
Reconstruct images based on target_features.

Args:
target_features (numpy.ndarray): Deep representations of original images. The first dimension of
target_features should be img_num. It should be noted that the shape of target_features should be
(1, dim2, dim3, ...) if img_num equals 1.
iters (int): iteration times of inversion attack, which should be positive integers. Default: 100.

Returns:
numpy.ndarray, reconstructed images, which are expected to be similar to original images.

Raises:
TypeError: If the type of target_features is not numpy.ndarray.
ValueError: If any value of iters is not positive int.Z

Examples:
>>> net = LeNet5()
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1),
>>> loss_weights=[1, 0.2, 5])
>>> features = np.random.random((2, 10)).astype(np.float32)
>>> images = inversion_attack.generate(features, iters=10)
>>> print(images.shape)
(2, 1, 32, 32)
"""
target_features = check_numpy_param('target_features', target_features)
iters = check_int_positive('iters', iters)

# shape checking
img_num = target_features.shape[0]
test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32)
test_out = self._network(Tensor(test_input)).asnumpy()
if test_out.shape != target_features.shape:
msg = "The shape of target_features ({}) is not in accordance with the shape" \
" of network output({})".format(target_features.shape, test_out.shape)
raise ValueError(msg)

loss_net = self._loss
loss_grad = GradWrapWithLoss(loss_net)

inversion_images = []
for i in range(img_num):
target_feature_n = target_features[i]
inversion_image_n = np.random.random((1,) + self._input_shape).astype(np.float32)*0.05
for s in range(iters):
x_grad = loss_grad(Tensor(inversion_image_n), Tensor(target_feature_n)).asnumpy()
x_grad_sign = np.sign(x_grad)
inversion_image_n -= x_grad_sign*0.01
inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1])
current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n))
LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss))
inversion_images.append(inversion_image_n)
return np.concatenate(np.array(inversion_images))

def evaluate(self, original_images, inversion_images):
"""
Compute the average L2 distance and SSIM value between original images and inversion images.

Args:
original_images (numpy.ndarray): Original images, whose shape should be (img_num, channels, img_width,
img_height).
inversion_images (numpy.ndarray): Inversion images, whose shape should be (img_num, channels, img_width,
img_height).

Returns:
tuple, the average l2 distance and average ssim value between original images and inversion images.

Examples:
>>> net = LeNet5()
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1),
>>> loss_weights=[1, 0.2, 5])
>>> features = np.random.random((2, 10)).astype(np.float32)
>>> inver_images = inversion_attack.generate(features, iters=10)
>>> ori_images = np.random.random((2, 1, 32, 32))
>>> result = inversion_attack.evaluate(ori_images, inver_images)
>>> print(len(result))
2
"""
check_numpy_param('original_images', original_images)
check_numpy_param('inversion_images', inversion_images)
img_1, img_2 = check_equal_shape('original_images', original_images, 'inversion_images', inversion_images)
if (len(img_1.shape) != 4) or (img_1.shape[1] != 1 and img_1.shape[1] != 3):
msg = 'The shape format of img_1 and img_2 should be (img_num, channels, img_width, img_height),' \
' but got {} and {}'.format(img_1.shape, img_2.shape)
raise ValueError(msg)
total_l2_distance = 0
total_ssim = 0
img_1 = img_1.transpose(0, 2, 3, 1)
img_2 = img_2.transpose(0, 2, 3, 1)
for i in range(img_1.shape[0]):
_, l2_dis, _ = calculate_lp_distance(img_1[i], img_2[i])
total_l2_distance += l2_dis
total_ssim += compute_ssim(img_1[i], img_2[i])
avg_l2_dis = total_l2_distance / img_1.shape[0]
avg_ssim = total_ssim / img_1.shape[0]
return avg_l2_dis, avg_ssim

+ 5
- 0
mindarmour/utils/_check_param.py View File

@@ -61,6 +61,11 @@ def check_param_multi_types(arg_name, arg_value, valid_types):

def check_int_positive(arg_name, arg_value):
"""Check positive integer."""
# 'True' is treated as int(1) in python, which is a bug.
if isinstance(arg_value, bool):
msg = '{} should not be bool value, but got {}'.format(arg_name, arg_value)
LOGGER.error(TAG, msg)
raise ValueError(msg)
arg_value = check_param_type(arg_name, arg_value, int)
if arg_value <= 0:
msg = '{} must be greater than 0, but got {}'.format(arg_name,


+ 110
- 2
mindarmour/utils/util.py View File

@@ -13,11 +13,13 @@
# limitations under the License.
""" Util for MindArmour. """
import numpy as np
from scipy.ndimage.filters import convolve

from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types, check_equal_shape

from .logger import LogUtil

@@ -61,7 +63,7 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes)
Args:
grad_wrap_net (Cell): A network wrapped by GradWrap.
inputs (numpy.ndarray): Input samples.
num_boxes (int): Number of boxes infered by each image.
num_boxes (int): Number of boxes inferred by each image.
num_classes (int): Number of labels of model output.

Returns:
@@ -251,3 +253,109 @@ def to_tensor_tuple(inputs_ori):
else:
inputs_tensor = (Tensor(inputs_ori),)
return inputs_tensor


def calculate_lp_distance(original_image, compared_image):
"""
Calculate l0, l2 and linf distance for two images with the same shape.

Args:
original_image (np.ndarray): Original image.
compared_image (np.ndarray): Another image for comparison.

Returns:
tuple, (l0, l2 and linf) distances between two images.

Raises:
TypeError: If type of original_image or type of compared_image is not numpy.ndarray.
ValueError: If the shape of original_image and compared_image are not the same.
"""
check_numpy_param('original_image', original_image)
check_numpy_param('compared_image', compared_image)
check_equal_shape('original_image', original_image, 'compared_image', compared_image)
avoid_zero_div = 1e-14
diff = (original_image - compared_image).flatten()
data = original_image.flatten()
l0_dist = np.linalg.norm(diff, ord=0) \
/ (np.linalg.norm(data, ord=0) + avoid_zero_div)
l2_dist = np.linalg.norm(diff, ord=2) \
/ (np.linalg.norm(data, ord=2) + avoid_zero_div)
linf_dist = np.linalg.norm(diff, ord=np.inf) \
/ (np.linalg.norm(data, ord=np.inf) + avoid_zero_div)
return l0_dist, l2_dist, linf_dist


def compute_ssim(img_1, img_2, kernel_sigma=1.5, kernel_width=11):
"""
compute structural similarity between two images.

Args:
img_1 (numpy.ndarray): The first image to be compared. The shape of img_1 should be (img_width, img_height,
channels).
img_2 (numpy.ndarray): The second image to be compared. The shape of img_2 should be (img_width, img_height,
channels).
kernel_sigma (float): Gassian kernel param. Default: 1.5.
kernel_width (int): Another Gassian kernel param. Default: 11.

Returns:
float, structural similarity.
"""
img_1, img_2 = check_equal_shape('images_1', img_1, 'images_2', img_2)
if len(img_1.shape) > 2:
if (len(img_1.shape) != 3) or (img_1.shape[2] != 1 and img_1.shape[2] != 3):
msg = 'The shape format of img_1 and img_2 should be (img_width, img_height, channels),' \
' but got {} and {}'.format(img_1.shape, img_2.shape)
raise ValueError(msg)

if len(img_1.shape) > 2:
total_ssim = 0
for i in range(img_1.shape[2]):
total_ssim += compute_ssim(img_1[:, :, i], img_2[:, :, i])
return total_ssim / 3

# Create gaussian kernel
gaussian_kernel = np.zeros((kernel_width, kernel_width))
for i in range(kernel_width):
for j in range(kernel_width):
gaussian_kernel[i, j] = (1 / (2*np.pi*(kernel_sigma**2)))*np.exp(
- (((i - 5)**2) + ((j - 5)**2)) / (2*(kernel_sigma**2)))

img_1 = img_1.astype(np.float32)
img_2 = img_2.astype(np.float32)

img_sq_1 = img_1**2
img_sq_2 = img_2**2
img_12 = img_1*img_2

# Mean
img_mu_1 = convolve(img_1, gaussian_kernel)
img_mu_2 = convolve(img_2, gaussian_kernel)

# Mean square
img_mu_sq_1 = img_mu_1**2
img_mu_sq_2 = img_mu_2**2
img_mu_12 = img_mu_1*img_mu_2

# Variances
img_sigma_sq_1 = convolve(img_sq_1, gaussian_kernel)
img_sigma_sq_2 = convolve(img_sq_2, gaussian_kernel)

# Covariance
img_sigma_12 = convolve(img_12, gaussian_kernel)

# Centered squares of variances
img_sigma_sq_1 = img_sigma_sq_1 - img_mu_sq_1
img_sigma_sq_2 = img_sigma_sq_2 - img_mu_sq_2
img_sigma_12 = img_sigma_12 - img_mu_12

k_1 = 0.01
k_2 = 0.03
c_1 = (k_1*255)**2
c_2 = (k_2*255)**2

# Calculate ssim
num_ssim = (2*img_mu_12 + c_1)*(2*img_sigma_12 + c_2)
den_ssim = (img_mu_sq_1 + img_mu_sq_2 + c_1)*(img_sigma_sq_1
+ img_sigma_sq_2 + c_2)
res = np.average(num_ssim / den_ssim)
return res

+ 41
- 0
tests/ut/python/privacy/evaluation/test_inversion_attack.py View File

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inversion attack test
"""
import pytest

import numpy as np

import mindspore.context as context

from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack

from ut.python.utils.mock_net import Net


context.set_context(mode=context.GRAPH_MODE)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_inversion_attack():
net = Net()
target_features = np.random.random((2, 10)).astype(np.float32)
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
inversion_images = inversion_attack.generate(target_features, iters=10)
assert target_features.shape[0] == inversion_images.shape[0]

Loading…
Cancel
Save