Browse Source

Fix an issue of inversion attack

tags/v1.6.0
jin-xiulang 4 years ago
parent
commit
30a17fa462
2 changed files with 35 additions and 7 deletions
  1. +19
    -7
      mindarmour/privacy/evaluation/inversion_attack.py
  2. +16
    -0
      tests/ut/python/privacy/evaluation/test_inversion_attack.py

+ 19
- 7
mindarmour/privacy/evaluation/inversion_attack.py View File

@@ -20,6 +20,7 @@ from scipy.special import softmax
from mindspore.nn import Cell, MSELoss
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import context

from mindarmour.utils.util import GradWrapWithLoss
from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \
@@ -48,6 +49,7 @@ class InversionLoss(Cell):
self._weights = check_param_multi_types('weights', weights, [list, tuple])
self._get_shape = P.Shape()
self._zeros = P.ZerosLike()
self._device_target = context.get_context("device_target")

def construct(self, input_data, target_features):
"""
@@ -66,11 +68,22 @@ class InversionLoss(Cell):
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0)

data_shape = self._get_shape(input_data)
data_copy_1 = self._zeros(input_data)
data_copy_2 = self._zeros(input_data)
data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :]
data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:]
loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2)
if self._device_target == 'CPU':
split_op_1 = P.Split(2, data_shape[2])
split_op_2 = P.Split(3, data_shape[3])
data_split_1 = split_op_1(input_data)
data_split_2 = split_op_2(input_data)
loss_2 = 0
for i in range(1, data_shape[2]):
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i - 1])
for j in range(1, data_shape[3]):
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j - 1])
else:
data_copy_1 = self._zeros(input_data)
data_copy_2 = self._zeros(input_data)
data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :]
data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:]
loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2)

loss_3 = self._mse_loss(input_data, 0)

@@ -152,7 +165,6 @@ class ImageInversionAttack:
msg = "The shape of target_features ({}) is not in accordance with the shape" \
" of network output({})".format(target_features.shape, test_out.shape)
raise ValueError(msg)

loss_net = self._loss
loss_grad = GradWrapWithLoss(loss_net)

@@ -165,7 +177,7 @@ class ImageInversionAttack:
x_grad_sign = np.sign(x_grad)
inversion_image_n -= x_grad_sign*0.01
inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1])
current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n))
current_loss = loss_net(Tensor(inversion_image_n), Tensor(target_feature_n))
LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss))
inversion_images.append(inversion_image_n)
return np.concatenate(np.array(inversion_images))


+ 16
- 0
tests/ut/python/privacy/evaluation/test_inversion_attack.py View File

@@ -60,6 +60,22 @@ def test_inversion_attack_pynative():


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_inversion_attack_cpu():
context.set_context(device_target='CPU')
net = Net()
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32)
target_features = np.random.random((2, 10)).astype(np.float32)
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
inversion_images = inversion_attack.generate(target_features, iters=10)
avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
assert 0 < avg_ssim[1] < 1
assert target_features.shape[0] == inversion_images.shape[0]


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard


Loading…
Cancel
Save