From 264017835f9a7a5c2d3228aff1ba705c7f0e24c4 Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Thu, 8 Apr 2021 09:10:18 +0800 Subject: [PATCH] Fix a bug of inversion-attack --- mindarmour/privacy/evaluation/inversion_attack.py | 15 ++++++--------- .../privacy/evaluation/test_inversion_attack.py | 19 +++++++++++++++++-- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index 4fd959e..9d38f13 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -47,6 +47,7 @@ class InversionLoss(Cell): self._mse_loss = MSELoss() self._weights = check_param_multi_types('weights', weights, [list, tuple]) self._get_shape = P.Shape() + self._zeros = P.ZerosLike() def construct(self, input_data, target_features): """ @@ -65,15 +66,11 @@ class InversionLoss(Cell): loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) data_shape = self._get_shape(input_data) - split_op_1 = P.Split(2, data_shape[2]) - split_op_2 = P.Split(3, data_shape[3]) - data_split_1 = split_op_1(input_data) - data_split_2 = split_op_2(input_data) - loss_2 = 0 - for i in range(1, data_shape[2]): - loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1]) - for j in range(1, data_shape[3]): - loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1]) + data_copy_1 = self._zeros(input_data) + data_copy_2 = self._zeros(input_data) + data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :] + data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] + loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) loss_3 = self._mse_loss(input_data, 0) diff --git a/tests/ut/python/privacy/evaluation/test_inversion_attack.py b/tests/ut/python/privacy/evaluation/test_inversion_attack.py index 23651f6..8727263 100644 --- a/tests/ut/python/privacy/evaluation/test_inversion_attack.py +++ b/tests/ut/python/privacy/evaluation/test_inversion_attack.py @@ -25,7 +25,21 @@ from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack from tests.ut.python.utils.mock_net import Net -context.set_context(mode=context.GRAPH_MODE) +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_inversion_attack_graph(): + context.set_context(mode=context.GRAPH_MODE) + net = Net() + original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) + target_features = np.random.random((2, 10)).astype(np.float32) + inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) + inversion_images = inversion_attack.generate(target_features, iters=10) + avg_ssim = inversion_attack.evaluate(original_images, inversion_images) + assert 0 < avg_ssim[1] < 1 + assert target_features.shape[0] == inversion_images.shape[0] @pytest.mark.level0 @@ -33,7 +47,8 @@ context.set_context(mode=context.GRAPH_MODE) @pytest.mark.platform_arm_ascend_training @pytest.mark.env_onecard @pytest.mark.component_mindarmour -def test_inversion_attack(): +def test_inversion_attack_pynative(): + context.set_context(mode=context.PYNATIVE_MODE) net = Net() original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) target_features = np.random.random((2, 10)).astype(np.float32)