Browse Source

Fix a bug of inversion-attack

tags/v1.6.0
jin-xiulang 4 years ago
parent
commit
264017835f
2 changed files with 23 additions and 11 deletions
  1. +6
    -9
      mindarmour/privacy/evaluation/inversion_attack.py
  2. +17
    -2
      tests/ut/python/privacy/evaluation/test_inversion_attack.py

+ 6
- 9
mindarmour/privacy/evaluation/inversion_attack.py View File

@@ -47,6 +47,7 @@ class InversionLoss(Cell):
self._mse_loss = MSELoss()
self._weights = check_param_multi_types('weights', weights, [list, tuple])
self._get_shape = P.Shape()
self._zeros = P.ZerosLike()

def construct(self, input_data, target_features):
"""
@@ -65,15 +66,11 @@ class InversionLoss(Cell):
loss_1 = self._mse_loss(output, target_features) / self._mse_loss(target_features, 0)

data_shape = self._get_shape(input_data)
split_op_1 = P.Split(2, data_shape[2])
split_op_2 = P.Split(3, data_shape[3])
data_split_1 = split_op_1(input_data)
data_split_2 = split_op_2(input_data)
loss_2 = 0
for i in range(1, data_shape[2]):
loss_2 += self._mse_loss(data_split_1[i], data_split_1[i-1])
for j in range(1, data_shape[3]):
loss_2 += self._mse_loss(data_split_2[j], data_split_2[j-1])
data_copy_1 = self._zeros(input_data)
data_copy_2 = self._zeros(input_data)
data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :]
data_copy_2[:, :, :, :(data_shape[2] - 1)] = input_data[:, :, :, 1:]
loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2)

loss_3 = self._mse_loss(input_data, 0)



+ 17
- 2
tests/ut/python/privacy/evaluation/test_inversion_attack.py View File

@@ -25,7 +25,21 @@ from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
from tests.ut.python.utils.mock_net import Net


context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_inversion_attack_graph():
context.set_context(mode=context.GRAPH_MODE)
net = Net()
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32)
target_features = np.random.random((2, 10)).astype(np.float32)
inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.2, 5])
inversion_images = inversion_attack.generate(target_features, iters=10)
avg_ssim = inversion_attack.evaluate(original_images, inversion_images)
assert 0 < avg_ssim[1] < 1
assert target_features.shape[0] == inversion_images.shape[0]


@pytest.mark.level0
@@ -33,7 +47,8 @@ context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_inversion_attack():
def test_inversion_attack_pynative():
context.set_context(mode=context.PYNATIVE_MODE)
net = Net()
original_images = np.random.random((2, 1, 32, 32)).astype(np.float32)
target_features = np.random.random((2, 10)).astype(np.float32)


Loading…
Cancel
Save