Browse Source

optimize imp logic of PGD and JSMA

pull/393/head
pkuliuliu 2 years ago
parent
commit
83451c020f
4 changed files with 25 additions and 30 deletions
  1. +0
    -1
      mindarmour/adv_robustness/attacks/gradient_method.py
  2. +23
    -27
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  3. +1
    -2
      mindarmour/adv_robustness/attacks/jsma.py
  4. +1
    -0
      mindarmour/adv_robustness/defenses/adversarial_defense.py

+ 0
- 1
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -69,7 +69,6 @@ class GradientMethod(Attack):
else: else:
with_loss_cell = WithLossCell(self._network, loss_fn) with_loss_cell = WithLossCell(self._network, loss_fn)
self._grad_all = GradWrapWithLoss(with_loss_cell) self._grad_all = GradWrapWithLoss(with_loss_cell)
self._grad_all.set_train()


def generate(self, inputs, labels): def generate(self, inputs, labels):
""" """


+ 23
- 27
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -14,6 +14,7 @@
""" Iterative gradient method attack. """ """ Iterative gradient method attack. """
from abc import abstractmethod from abc import abstractmethod


import copy
import numpy as np import numpy as np
from PIL import Image, ImageOps from PIL import Image, ImageOps


@@ -68,13 +69,14 @@ def _reshape_l1_projection(values, eps=3):
return proj_x return proj_x




def _projection(values, eps, norm_level):
def _projection(values, eps, clip_diff, norm_level):
""" """
Implementation of values normalization within eps. Implementation of values normalization within eps.


Args: Args:
values (numpy.ndarray): Input data. values (numpy.ndarray): Input data.
eps (float): Project radius. eps (float): Project radius.
clip_diff (float): Difference range of clip bounds.
norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible
values: np.inf, 1 or 2. values: np.inf, 1 or 2.


@@ -88,12 +90,12 @@ def _projection(values, eps, norm_level):
if norm_level in (1, '1'): if norm_level in (1, '1'):
sample_batch = values.shape[0] sample_batch = values.shape[0]
x_flat = values.reshape(sample_batch, -1) x_flat = values.reshape(sample_batch, -1)
proj_flat = _reshape_l1_projection(x_flat, eps)
proj_flat = _reshape_l1_projection(x_flat, eps*clip_diff)
return proj_flat.reshape(values.shape) return proj_flat.reshape(values.shape)
if norm_level in (2, '2'): if norm_level in (2, '2'):
return eps*normalize_value(values, norm_level) return eps*normalize_value(values, norm_level)
if norm_level in (np.inf, 'inf'): if norm_level in (np.inf, 'inf'):
return eps*np.sign(values)
return eps*clip_diff*np.sign(values)
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \
'currently not supported.' 'currently not supported.'
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
@@ -132,7 +134,6 @@ class IterativeGradientMethod(Attack):
self._loss_grad = network self._loss_grad = network
else: else:
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn))
self._loss_grad.set_train()


@abstractmethod @abstractmethod
def generate(self, inputs, labels): def generate(self, inputs, labels):
@@ -470,33 +471,28 @@ class ProjectedGradientDescent(BasicIterativeMethod):
""" """
inputs_image, inputs, labels = check_inputs_labels(inputs, labels) inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image arr_x = inputs_image
adv_x = copy.deepcopy(inputs_image)
if self._bounds is not None: if self._bounds is not None:
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min clip_diff = clip_max - clip_min
for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels)
perturs = _projection(adv_x - arr_x,
self._eps,
norm_level=self._norm_level)
perturs = np.clip(perturs, (0 - self._eps)*clip_diff,
self._eps*clip_diff)
adv_x = arr_x + perturs
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
else: else:
for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels)
perturs = _projection(adv_x - arr_x,
self._eps,
norm_level=self._norm_level)
adv_x = arr_x + perturs
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
clip_diff = 1

for _ in range(self._nb_iter):
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
gradient = out_grad.asnumpy()
perturbs = _projection(gradient, self._eps_iter, clip_diff, norm_level=self._norm_level)
sum_perturbs = adv_x - arr_x + perturbs
sum_perturbs = np.clip(sum_perturbs, (0 - self._eps)*clip_diff, self._eps*clip_diff)
adv_x = arr_x + sum_perturbs
if self._bounds is not None:
adv_x = np.clip(adv_x, clip_min, clip_max)
if isinstance(inputs, tuple):
inputs = (adv_x,) + inputs[1:]
else:
inputs = adv_x
return adv_x return adv_x






+ 1
- 2
mindarmour/adv_robustness/attacks/jsma.py View File

@@ -150,7 +150,6 @@ class JSMAAttack(Attack):
ori_shape = data.shape ori_shape = data.shape
temp = data.flatten() temp = data.flatten()
bit_map = np.ones_like(temp) bit_map = np.ones_like(temp)
fake_res = np.zeros_like(data)
counter = np.zeros_like(temp) counter = np.zeros_like(temp)
perturbed = np.copy(temp) perturbed = np.copy(temp)
for _ in range(self._max_iter): for _ in range(self._max_iter):
@@ -183,7 +182,7 @@ class JSMAAttack(Attack):
bit_map[p2_ind] = 0 bit_map[p2_ind] = 0
perturbed = np.clip(perturbed, self._min, self._max) perturbed = np.clip(perturbed, self._min, self._max)
LOGGER.debug(TAG, 'fail to find adversarial sample.') LOGGER.debug(TAG, 'fail to find adversarial sample.')
return fake_res
return perturbed.reshape(ori_shape)


def generate(self, inputs, labels): def generate(self, inputs, labels):
""" """


+ 1
- 0
mindarmour/adv_robustness/defenses/adversarial_defense.py View File

@@ -162,6 +162,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
replace_ratio, replace_ratio,
0, 1) 0, 1)
self._graph_initialized = False self._graph_initialized = False
self._train_net.set_train()


def defense(self, inputs, labels): def defense(self, inputs, labels):
""" """


Loading…
Cancel
Save