diff --git a/mindarmour/privacy/evaluation/inversion_attack.py b/mindarmour/privacy/evaluation/inversion_attack.py index 9d38f13..fe06abb 100644 --- a/mindarmour/privacy/evaluation/inversion_attack.py +++ b/mindarmour/privacy/evaluation/inversion_attack.py @@ -20,6 +20,7 @@ from scipy.special import softmax from mindspore.nn import Cell, MSELoss from mindspore import Tensor from mindspore.ops import operations as P +from mindspore import context from mindarmour.utils.util import GradWrapWithLoss from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \ @@ -48,6 +49,7 @@ class InversionLoss(Cell): self._weights = check_param_multi_types('weights', weights, [list, tuple]) self._get_shape = P.Shape() self._zeros = P.ZerosLike() + self._device_target = context.get_context("device_target") def construct(self, input_data, target_features): """ @@ -66,11 +68,22 @@ class InversionLoss(Cell): loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0) data_shape = self._get_shape(input_data) - data_copy_1 = self._zeros(input_data) - data_copy_2 = self._zeros(input_data) - data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :] - data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] - loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) + if self._device_target == 'CPU': + split_op_1 = P.Split(2, data_shape[2]) + split_op_2 = P.Split(3, data_shape[3]) + data_split_1 = split_op_1(input_data) + data_split_2 = split_op_2(input_data) + loss_2 = 0 + for i in range(1, data_shape[2]): + loss_2 += self._mse_loss(data_split_1[i], data_split_1[i - 1]) + for j in range(1, data_shape[3]): + loss_2 += self._mse_loss(data_split_2[j], data_split_2[j - 1]) + else: + data_copy_1 = self._zeros(input_data) + data_copy_2 = self._zeros(input_data) + data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :] + data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:] + loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2) loss_3 = self._mse_loss(input_data, 0) @@ -152,7 +165,6 @@ class ImageInversionAttack: msg = "The shape of target_features ({}) is not in accordance with the shape" \ " of network output({})".format(target_features.shape, test_out.shape) raise ValueError(msg) - loss_net = self._loss loss_grad = GradWrapWithLoss(loss_net) @@ -165,7 +177,7 @@ class ImageInversionAttack: x_grad_sign = np.sign(x_grad) inversion_image_n -= x_grad_sign*0.01 inversion_image_n = np.clip(inversion_image_n, self._input_bound[0], self._input_bound[1]) - current_loss = self._loss(Tensor(inversion_image_n), Tensor(target_feature_n)) + current_loss = loss_net(Tensor(inversion_image_n), Tensor(target_feature_n)) LOGGER.info(TAG, 'iteration step: {}, loss is {}'.format(s, current_loss)) inversion_images.append(inversion_image_n) return np.concatenate(np.array(inversion_images)) diff --git a/tests/ut/python/privacy/evaluation/test_inversion_attack.py b/tests/ut/python/privacy/evaluation/test_inversion_attack.py index 8727263..a1f7cf7 100644 --- a/tests/ut/python/privacy/evaluation/test_inversion_attack.py +++ b/tests/ut/python/privacy/evaluation/test_inversion_attack.py @@ -60,6 +60,22 @@ def test_inversion_attack_pynative(): @pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_inversion_attack_cpu(): + context.set_context(device_target='CPU') + net = Net() + original_images = np.random.random((2, 1, 32, 32)).astype(np.float32) + target_features = np.random.random((2, 10)).astype(np.float32) + inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5]) + inversion_images = inversion_attack.generate(target_features, iters=10) + avg_ssim = inversion_attack.evaluate(original_images, inversion_images) + assert 0 < avg_ssim[1] < 1 + assert target_features.shape[0] == inversion_images.shape[0] + + +@pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training @pytest.mark.env_onecard