Browse Source

!393 Optimize imp logic of PGD and JSMA

Merge pull request !393 from pkuliuliu/master
pull/395/MERGE
i-robot Gitee 2 years ago
parent
commit
62c61d70b6
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 25 additions and 30 deletions
  1. +0
    -1
      mindarmour/adv_robustness/attacks/gradient_method.py
  2. +23
    -27
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  3. +1
    -2
      mindarmour/adv_robustness/attacks/jsma.py
  4. +1
    -0
      mindarmour/adv_robustness/defenses/adversarial_defense.py

+ 0
- 1
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -69,7 +69,6 @@ class GradientMethod(Attack):
else:
with_loss_cell = WithLossCell(self._network, loss_fn)
self._grad_all = GradWrapWithLoss(with_loss_cell)
self._grad_all.set_train()

def generate(self, inputs, labels):
"""


+ 23
- 27
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -14,6 +14,7 @@
""" Iterative gradient method attack. """
from abc import abstractmethod

import copy
import numpy as np
from PIL import Image, ImageOps

@@ -68,13 +69,14 @@ def _reshape_l1_projection(values, eps=3):
return proj_x


def _projection(values, eps, norm_level):
def _projection(values, eps, clip_diff, norm_level):
"""
Implementation of values normalization within eps.

Args:
values (numpy.ndarray): Input data.
eps (float): Project radius.
clip_diff (float): Difference range of clip bounds.
norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible
values: np.inf, 1 or 2.

@@ -88,12 +90,12 @@ def _projection(values, eps, norm_level):
if norm_level in (1, '1'):
sample_batch = values.shape[0]
x_flat = values.reshape(sample_batch, -1)
proj_flat = _reshape_l1_projection(x_flat, eps)
proj_flat = _reshape_l1_projection(x_flat, eps*clip_diff)
return proj_flat.reshape(values.shape)
if norm_level in (2, '2'):
return eps*normalize_value(values, norm_level)
if norm_level in (np.inf, 'inf'):
return eps*np.sign(values)
return eps*clip_diff*np.sign(values)
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \
'currently not supported.'
LOGGER.error(TAG, msg)
@@ -132,7 +134,6 @@ class IterativeGradientMethod(Attack):
self._loss_grad = network
else:
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
self._loss_grad.set_train()

@abstractmethod
def generate(self, inputs, labels):
@@ -470,33 +471,28 @@ class ProjectedGradientDescent(BasicIterativeMethod):
"""
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image
adv_x = copy.deepcopy(inputs_image)
if self._bounds is not None:
clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min
for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels)
perturs = _projection(adv_x - arr_x,
self._eps,
norm_level=self._norm_level)
perturs = np.clip(perturs, (0 - self._eps)*clip_diff,
self._eps*clip_diff)
adv_x = arr_x + perturs
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
else:
for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels)
perturs = _projection(adv_x - arr_x,
self._eps,
norm_level=self._norm_level)
adv_x = arr_x + perturs
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
clip_diff = 1

for _ in range(self._nb_iter):
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
gradient = out_grad.asnumpy()
perturbs = _projection(gradient, self._eps_iter, clip_diff, norm_level=self._norm_level)
sum_perturbs = adv_x - arr_x + perturbs
sum_perturbs = np.clip(sum_perturbs, (0 - self._eps)*clip_diff, self._eps*clip_diff)
adv_x = arr_x + sum_perturbs
if self._bounds is not None:
adv_x = np.clip(adv_x, clip_min, clip_max)
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
return adv_x




+ 1
- 2
mindarmour/adv_robustness/attacks/jsma.py View File

@@ -150,7 +150,6 @@ class JSMAAttack(Attack):
ori_shape = data.shape
temp = data.flatten()
bit_map = np.ones_like(temp)
fake_res = np.zeros_like(data)
counter = np.zeros_like(temp)
perturbed = np.copy(temp)
for _ in range(self._max_iter):
@@ -183,7 +182,7 @@ class JSMAAttack(Attack):
bit_map[p2_ind] = 0
perturbed = np.clip(perturbed, self._min, self._max)
LOGGER.debug(TAG, 'fail to find adversarial sample.')
return fake_res
return perturbed.reshape(ori_shape)

def generate(self, inputs, labels):
"""


+ 1
- 0
mindarmour/adv_robustness/defenses/adversarial_defense.py View File

@@ -162,6 +162,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
replace_ratio,
0, 1)
self._graph_initialized = False
self._train_net.set_train()

def defense(self, inputs, labels):
"""


Loading…
Cancel
Save