From 83451c020f1cc70f54c14d160c8624677a1033b7 Mon Sep 17 00:00:00 2001 From: pkuliuliu Date: Mon, 8 Aug 2022 10:20:17 +0800 Subject: [PATCH] optimize imp logic of PGD and JSMA --- .../adv_robustness/attacks/gradient_method.py | 1 - .../attacks/iterative_gradient_method.py | 50 ++++++++++------------ mindarmour/adv_robustness/attacks/jsma.py | 3 +- .../adv_robustness/defenses/adversarial_defense.py | 1 + 4 files changed, 25 insertions(+), 30 deletions(-) diff --git a/mindarmour/adv_robustness/attacks/gradient_method.py b/mindarmour/adv_robustness/attacks/gradient_method.py index 9a154c3..1ca002c 100644 --- a/mindarmour/adv_robustness/attacks/gradient_method.py +++ b/mindarmour/adv_robustness/attacks/gradient_method.py @@ -69,7 +69,6 @@ class GradientMethod(Attack): else: with_loss_cell = WithLossCell(self._network, loss_fn) self._grad_all = GradWrapWithLoss(with_loss_cell) - self._grad_all.set_train() def generate(self, inputs, labels): """ diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py index 4ec509c..99e45b8 100644 --- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py +++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py @@ -14,6 +14,7 @@ """ Iterative gradient method attack. """ from abc import abstractmethod +import copy import numpy as np from PIL import Image, ImageOps @@ -68,13 +69,14 @@ def _reshape_l1_projection(values, eps=3): return proj_x -def _projection(values, eps, norm_level): +def _projection(values, eps, clip_diff, norm_level): """ Implementation of values normalization within eps. Args: values (numpy.ndarray): Input data. eps (float): Project radius. + clip_diff (float): Difference range of clip bounds. norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible values: np.inf, 1 or 2. @@ -88,12 +90,12 @@ def _projection(values, eps, norm_level): if norm_level in (1, '1'): sample_batch = values.shape[0] x_flat = values.reshape(sample_batch, -1) - proj_flat = _reshape_l1_projection(x_flat, eps) + proj_flat = _reshape_l1_projection(x_flat, eps*clip_diff) return proj_flat.reshape(values.shape) if norm_level in (2, '2'): return eps*normalize_value(values, norm_level) if norm_level in (np.inf, 'inf'): - return eps*np.sign(values) + return eps*clip_diff*np.sign(values) msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ 'currently not supported.' LOGGER.error(TAG, msg) @@ -132,7 +134,6 @@ class IterativeGradientMethod(Attack): self._loss_grad = network else: self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) - self._loss_grad.set_train() @abstractmethod def generate(self, inputs, labels): @@ -470,33 +471,28 @@ class ProjectedGradientDescent(BasicIterativeMethod): """ inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs_image + adv_x = copy.deepcopy(inputs_image) if self._bounds is not None: clip_min, clip_max = self._bounds clip_diff = clip_max - clip_min - for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) - perturs = _projection(adv_x - arr_x, - self._eps, - norm_level=self._norm_level) - perturs = np.clip(perturs, (0 - self._eps)*clip_diff, - self._eps*clip_diff) - adv_x = arr_x + perturs - if isinstance(inputs, tuple): - inputs = (adv_x,) + inputs[1:] - else: - inputs = adv_x else: - for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) - perturs = _projection(adv_x - arr_x, - self._eps, - norm_level=self._norm_level) - adv_x = arr_x + perturs - adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) - if isinstance(inputs, tuple): - inputs = (adv_x,) + inputs[1:] - else: - inputs = adv_x + clip_diff = 1 + + for _ in range(self._nb_iter): + inputs_tensor = to_tensor_tuple(inputs) + labels_tensor = to_tensor_tuple(labels) + out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) + gradient = out_grad.asnumpy() + perturbs = _projection(gradient, self._eps_iter, clip_diff, norm_level=self._norm_level) + sum_perturbs = adv_x - arr_x + perturbs + sum_perturbs = np.clip(sum_perturbs, (0 - self._eps)*clip_diff, self._eps*clip_diff) + adv_x = arr_x + sum_perturbs + if self._bounds is not None: + adv_x = np.clip(adv_x, clip_min, clip_max) + if isinstance(inputs, tuple): + inputs = (adv_x,) + inputs[1:] + else: + inputs = adv_x return adv_x diff --git a/mindarmour/adv_robustness/attacks/jsma.py b/mindarmour/adv_robustness/attacks/jsma.py index 629d77c..5b50c3a 100644 --- a/mindarmour/adv_robustness/attacks/jsma.py +++ b/mindarmour/adv_robustness/attacks/jsma.py @@ -150,7 +150,6 @@ class JSMAAttack(Attack): ori_shape = data.shape temp = data.flatten() bit_map = np.ones_like(temp) - fake_res = np.zeros_like(data) counter = np.zeros_like(temp) perturbed = np.copy(temp) for _ in range(self._max_iter): @@ -183,7 +182,7 @@ class JSMAAttack(Attack): bit_map[p2_ind] = 0 perturbed = np.clip(perturbed, self._min, self._max) LOGGER.debug(TAG, 'fail to find adversarial sample.') - return fake_res + return perturbed.reshape(ori_shape) def generate(self, inputs, labels): """ diff --git a/mindarmour/adv_robustness/defenses/adversarial_defense.py b/mindarmour/adv_robustness/defenses/adversarial_defense.py index 520baeb..71c889c 100644 --- a/mindarmour/adv_robustness/defenses/adversarial_defense.py +++ b/mindarmour/adv_robustness/defenses/adversarial_defense.py @@ -162,6 +162,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense): replace_ratio, 0, 1) self._graph_initialized = False + self._train_net.set_train() def defense(self, inputs, labels): """