Add API Doc examples for attack classes

tags/v1.6.0
shu-kun-zhang 3 years ago
commit 7aa82a43de
13 changed files with 952 additions and 255 deletions
  1. +281 -130  mindarmour/adv_robustness/attacks/black/genetic_attack.py
  2. +27 -6  mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py
  3. +34 -12  mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py
  4. +22 -1  mindarmour/adv_robustness/attacks/black/pointwise_attack.py
  5. +276 -76  mindarmour/adv_robustness/attacks/black/pso_attack.py
  6. +22 -1  mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py
  7. +51 -7  mindarmour/adv_robustness/attacks/carlini_wagner.py
  8. +35 -10  mindarmour/adv_robustness/attacks/deep_fool.py
  9. +111 -5  mindarmour/adv_robustness/attacks/gradient_method.py
  10. +59 -1  mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  11. +24 -2  mindarmour/adv_robustness/attacks/jsma.py
  12. +8 -2  mindarmour/adv_robustness/attacks/lbfgs.py
  13. +2 -2  tests/ut/python/adv_robustness/attacks/black/test_nes.py

+281 -130  mindarmour/adv_robustness/attacks/black/genetic_attack.py

@@ -45,7 +45,7 @@ class GeneticAttack(Attack):
default: 'classification'.
targeted (bool): If True, turns on the targeted attack. If False,
turns on untargeted attack. It should be noted that only untargeted attack
is supproted for model_type='detection', Default: True.
is supported for model_type='detection', Default: True.
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks,
specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3.
pop_size (int): The number of particles, which should be greater than
@@ -69,7 +69,33 @@ class GeneticAttack(Attack):
c (Union[int, float]): Weight of perturbation loss. Default: 0.1.

Examples:
>>> attack = GeneticAttack(model)
>>> import numpy as np
>>> import mindspore.ops.operations as M
>>> from mindspore import Tensor
>>> from mindspore.nn import Cell
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import GeneticAttack
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._softmax = M.Softmax()
>>>
>>> def construct(self, inputs):
>>> out = self._softmax(inputs)
>>> return out
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = GeneticAttack(model, sparse=False)
"""
def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True,
pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3,
@@ -135,18 +161,77 @@ class GeneticAttack(Attack):
np.random.random(cur_pop.shape) < prob) + cur_pop
return mutated_pop

def generate(self, inputs, labels):

def _compute_next_generation(self, cur_pop, fit_vals, x_ori):
"""
Generate adversarial examples based on input data and targeted
labels (or ground_truth labels).
Compute the population for the next generation.

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if
model_type='detection'.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels)
if model_type='detection'.
cur_pop (numpy.ndarray): Samples before mutation.
fit_vals (numpy.ndarray): Fitness values of the current population.
x_ori (numpy.ndarray): Original input sample.

Returns:
- numpy.ndarray, population for the next generation.

- numpy.ndarray, elite, the best individual of the current generation.

Examples:
>>> cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori)
"""
best_fit = max(fit_vals)

if best_fit > self._best_fit:
self._best_fit = best_fit
self._plateau_times = 0
else:
self._plateau_times += 1
adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit)
if self._plateau_times > adap_threshold:
self._adap_times += 1
self._plateau_times = 0
if self._adaptive:
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times))
step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times))
else:
step_noise = self._step_size
step_p = self._mutation_rate
step_temp = self._temp
elite = cur_pop[np.argmax(fit_vals)]
select_probs = softmax(fit_vals/step_temp)
select_args = np.arange(self._pop_size)
parents_arg = np.random.choice(
a=select_args, size=2*(self._pop_size - 1),
replace=True, p=select_probs)
parent1 = cur_pop[parents_arg[:self._pop_size - 1]]
parent2 = cur_pop[parents_arg[self._pop_size - 1:]]
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]]
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]]
parent2_probs = parent2_probs / (parent1_probs + parent2_probs)
# duplicate the probabilities to all features of each particle.
dims = len(x_ori.shape)
for _ in range(dims):
parent2_probs = parent2_probs[:, np.newaxis]
parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape))
cross_probs = (np.random.random(parent1.shape) >
parent2_probs).astype(np.int32)
children = parent1*cross_probs + parent2*(1 - cross_probs)
mutated_children = self._mutation(
children, step_noise=self._per_bounds*step_noise,
prob=step_p)
cur_pop = np.concatenate((mutated_children, elite[np.newaxis, :]))

return cur_pop, elite



def _generate_classification(self, inputs, labels):
"""
Generate adversarial examples based on input data and
targeted labels (or ground_truth labels) for classification model.

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels.
The format of labels should be numpy.ndarray.

Returns:
- numpy.ndarray, bool values for each attack result.
@@ -156,28 +241,27 @@ class GeneticAttack(Attack):
- numpy.ndarray, query times for each sample.

Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4],
>>> [0.3, 0.3, 0.2]],
>>> [1, 2])
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack._generate_classification(x_test, y_test)
"""
if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if self._sparse:
if labels.size > 1:
label_squ = np.squeeze(labels)
else:
label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \
"and got its shape as {}.".format(labels.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels)
if self._sparse:
if labels.size > 1:
label_squ = np.squeeze(labels)
else:
labels = np.argmax(labels, axis=1)
images = inputs
elif self._model_type == 'detection':
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)
label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \
"and got its shape as {}.".format(labels.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
else:
labels = np.argmax(labels, axis=1)
images = inputs

adv_list = []
success_list = []
@@ -188,17 +272,7 @@ class GeneticAttack(Attack):
if not self._bounds:
self._bounds = [np.min(x_ori), np.max(x_ori)]
pixel_deep = self._bounds[1] - self._bounds[0]

if self._model_type == 'classification':
label_i = labels[i]
elif self._model_type == 'detection':
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0)
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])
label_i = labels[i]

# generate particles
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
@@ -215,106 +289,148 @@ class GeneticAttack(Attack):
ori_copies + pixel_deep*self._per_bounds),
self._bounds[0], self._bounds[1])

if self._model_type == 'classification':
pop_preds = self._model.predict(cur_pop)
query_times += cur_pop.shape[0]
all_preds = np.argmax(pop_preds, axis=1)
if self._targeted:
success_pop = np.equal(label_i, all_preds).astype(np.int32)
else:
success_pop = np.not_equal(label_i, all_preds).astype(np.int32)
is_success = max(success_pop)
best_idx = np.argmax(success_pop)
target_preds = pop_preds[:, label_i]
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds
if self._targeted:
fit_vals = target_preds - others_preds_sum
else:
fit_vals = others_preds_sum - target_preds

elif self._model_type == 'detection':
confi_adv, correct_nums_adv = self._detection_scores(
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
np.min(correct_nums_adv))
query_times += self._pop_size
fit_vals = abs(
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm(
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1)

if np.max(fit_vals) < 0:
self._c /= 2

if np.max(fit_vals) < -2:
LOGGER.debug(TAG,
'best fitness value is %s, which is too small. We recommend that you decrease '
'the value of the initialization parameter c.', np.max(fit_vals))
if iters < 3 and np.max(fit_vals) > 100:
LOGGER.debug(TAG,
'best fitness value is %s, which is too large. We recommend that you increase '
'the value of the initialization parameter c.', np.max(fit_vals))

if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio):
is_success = True
best_idx = np.argmin(correct_nums_adv)
pop_preds = self._model.predict(cur_pop)
query_times += cur_pop.shape[0]
all_preds = np.argmax(pop_preds, axis=1)
if self._targeted:
success_pop = np.equal(label_i, all_preds).astype(np.int32)
else:
success_pop = np.not_equal(label_i, all_preds).astype(np.int32)
is_success = max(success_pop)
best_idx = np.argmax(success_pop)
target_preds = pop_preds[:, label_i]
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds
if self._targeted:
fit_vals = target_preds - others_preds_sum
else:
fit_vals = others_preds_sum - target_preds

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial sample '
'and start Reduction process.')
final_adv = cur_pop[best_idx]
if self._model_type == 'classification':
final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv,
model=self._model, targeted_attack=self._targeted)

final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv,
model=self._model, targeted_attack=self._targeted)
break

best_fit = max(fit_vals)
cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori)

if best_fit > self._best_fit:
self._best_fit = best_fit
self._plateau_times = 0
else:
self._plateau_times += 1
adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit)
if self._plateau_times > adap_threshold:
self._adap_times += 1
self._plateau_times = 0
if self._adaptive:
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times))
step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times))
else:
step_noise = self._step_size
step_p = self._mutation_rate
step_temp = self._temp
elite = cur_pop[np.argmax(fit_vals)]
select_probs = softmax(fit_vals/step_temp)
select_args = np.arange(self._pop_size)
parents_arg = np.random.choice(
a=select_args, size=2*(self._pop_size - 1),
replace=True, p=select_probs)
parent1 = cur_pop[parents_arg[:self._pop_size - 1]]
parent2 = cur_pop[parents_arg[self._pop_size - 1:]]
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]]
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]]
parent2_probs = parent2_probs / (parent1_probs + parent2_probs)
# duplicate the probabilities to all features of each particle.
dims = len(x_ori.shape)
for _ in range(dims):
parent2_probs = parent2_probs[:, np.newaxis]
parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape))
cross_probs = (np.random.random(parent1.shape) >
parent2_probs).astype(np.int32)
childs = parent1*cross_probs + parent2*(1 - cross_probs)
mutated_childs = self._mutation(
childs, step_noise=self._per_bounds*step_noise,
prob=step_p)
cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :]))
if not is_success:
LOGGER.debug(TAG, 'fail to find adversarial sample.')
final_adv = elite
adv_list.append(final_adv)

LOGGER.debug(TAG,
'iteration times is: %d and query times is: %d',
iters,
query_times)
success_list.append(is_success)
query_times_list.append(query_times)
del ori_copies, cur_pert, cur_pop
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)



def _generate_detection(self, inputs, labels):
"""
Generate adversarial examples based on input data and
targeted labels (or ground_truth labels) for detection model.

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be only one array.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
be (gt_boxes, gt_labels).

Returns:
- numpy.ndarray, bool values for each attack result.

- numpy.ndarray, generated adversarial examples.

- numpy.ndarray, query times for each sample.

Examples:
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack._generate_detection(x_test, y_test)
"""
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)
adv_list = []
success_list = []
query_times_list = []
for i in range(images.shape[0]):
is_success = False
x_ori = images[i]
if not self._bounds:
self._bounds = [np.min(x_ori), np.max(x_ori)]
pixel_deep = self._bounds[1] - self._bounds[0]
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0)
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])

# generate particles
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
# initial perturbations
cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape)
cur_pop = ori_copies + cur_pert
query_times = 0
iters = 0

while iters < self._max_steps:
iters += 1
cur_pop = np.clip(np.clip(cur_pop,
ori_copies - pixel_deep*self._per_bounds,
ori_copies + pixel_deep*self._per_bounds),
self._bounds[0], self._bounds[1])

confi_adv, correct_nums_adv = self._detection_scores(
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
np.min(correct_nums_adv))
query_times += self._pop_size
fit_vals = abs(
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm(
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1)

if np.max(fit_vals) < 0:
self._c /= 2

if np.max(fit_vals) < -2:
LOGGER.debug(TAG,
'best fitness value is %s, which is too small. We recommend that you decrease '
'the value of the initialization parameter c.', np.max(fit_vals))
if iters < 3 and np.max(fit_vals) > 100:
LOGGER.debug(TAG,
'best fitness value is %s, which is too large. We recommend that you increase '
'the value of the initialization parameter c.', np.max(fit_vals))

if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio):
is_success = True
best_idx = np.argmin(correct_nums_adv)

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial sample '
'and start Reduction process.')
final_adv = cur_pop[best_idx]
break

cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori)

if not is_success:
LOGGER.debug(TAG, 'fail to find adversarial sample.')
final_adv = elite
if self._model_type == 'detection':
final_adv, query_times = self._fast_reduction(
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
final_adv, query_times = self._fast_reduction(
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
adv_list.append(final_adv)

LOGGER.debug(TAG,
@@ -327,3 +443,38 @@ class GeneticAttack(Attack):
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)

def generate(self, inputs, labels):
"""
Generate adversarial examples based on input data and targeted labels (or ground_truth labels).

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if
model_type='detection'.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels)
if model_type='detection'.

Returns:
- numpy.ndarray, bool values for each attack result.

- numpy.ndarray, generated adversarial examples.

- numpy.ndarray, query times for each sample.

Examples:
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack.generate(x_test, y_test)
"""
if self._model_type == 'classification':
success_list, adv_data, query_time_list = self._generate_classification(inputs, labels)

elif self._model_type == 'detection':
success_list, adv_data, query_time_list = self._generate_detection(inputs, labels)

return success_list, adv_data, query_time_list
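For the detection branch, which none of the new docstring examples exercise end to end, a hedged sketch of the expected call shape follows; the `detection_model` wrapper, the box/label shapes, and the class count are illustrative assumptions, not part of the patch.

    import numpy as np
    from mindarmour.adv_robustness.attacks import GeneticAttack

    # `detection_model` stands for a BlackModel-wrapped detection network (assumed).
    batch_size, num_gt = 2, 5
    images = np.random.rand(batch_size, 3, 128, 128).astype(np.float32)
    gt_boxes = np.random.rand(batch_size, num_gt, 4).astype(np.float32)   # illustrative box format
    gt_labels = np.random.randint(0, 10, size=(batch_size, num_gt)).astype(np.int32)

    # Only untargeted attacks are supported for model_type='detection'.
    attack = GeneticAttack(detection_model, model_type='detection',
                           targeted=False, reserve_ratio=0.3)
    success, advs, query_times = attack.generate(images, (gt_boxes, gt_labels))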

+27 -6  mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py

@@ -75,11 +75,26 @@ class HopSkipJumpAttack(Attack):
ValueError: If constraint not in ['l2', 'linf']

Examples:
>>> x_test = np.asarray(np.random.random((sample_num,
>>> sample_length)), np.float32)
>>> y_test = np.random.randint(0, class_num, size=sample_num)
>>> instance = HopSkipJumpAttack(user_model)
>>> adv_x = instance.generate(x_test, y_test)
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import HopSkipJumpAttack
>>> from tests.ut.python.utils.mock_net import Net
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> if len(inputs.shape) == 3:
>>> inputs = inputs[np.newaxis, :]
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = HopSkipJumpAttack(model)
"""

def __init__(self, model, init_num_evals=100, max_num_evals=1000,
@@ -173,7 +188,13 @@ class HopSkipJumpAttack(Attack):
- numpy.ndarray, query times for each sample.

Examples:
>>> generate([[0.1,0.2,0.2],[0.2,0.3,0.4]],[2,6])
>>> attack = HopSkipJumpAttack(model)
>>> n, c, h, w = 1, 1, 32, 32
>>> class_num = 3
>>> x_test = np.asarray(np.random.random((n,c,h,w)), np.float32)
>>> y_test = np.random.randint(0, class_num, size=n)
>>>
>>> _, adv_x, _= attack.generate(x_test, y_test)
"""
if labels is not None:
inputs, labels = check_pair_numpy_param('inputs', inputs,
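All of the black-box examples in this commit wrap the network in the same query-only `BlackModel` subclass. Pulled out of the docstrings, the pattern is simply the following, where `network` can be any MindSpore `Cell` (for instance the mock `Net` used in the tests):

    import numpy as np
    from mindspore import Tensor
    from mindarmour import BlackModel

    class ModelToBeAttacked(BlackModel):
        """Query-only wrapper: black-box attacks call predict() and never see gradients."""
        def __init__(self, network):
            super(ModelToBeAttacked, self).__init__()
            self._network = network

        def predict(self, inputs):
            if len(inputs.shape) == 3:
                inputs = inputs[np.newaxis, :]   # add a batch dimension for single images
            result = self._network(Tensor(inputs.astype(np.float32)))
            return result.asnumpy()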


+34 -12  mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py

@@ -79,16 +79,27 @@ class NES(Attack):
input labels are one-hot-encoded. Default: True.

Examples:
>>> SCENE = 'Label_Only'
>>> TOP_K = 5
>>> num_class = 5
>>> nes_instance = NES(user_model, SCENE, top_k=TOP_K)
>>> initial_img = np.asarray(np.random.random((32, 32)), np.float32)
>>> target_image = np.asarray(np.random.random((32, 32)), np.float32)
>>> orig_class = 0
>>> target_class = 2
>>> nes_instance.set_target_images(target_image)
>>> tag, adv, queries = nes_instance.generate([initial_img], [target_class])
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import NES
>>> from tests.ut.python.utils.mock_net import Net
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> if len(inputs.shape) == 3:
>>> inputs = inputs[np.newaxis, :]
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> SCENE = 'Query_Limit'
>>> TOP_K = -1
>>> attack= NES(model, SCENE, top_k=TOP_K)
"""

def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3,
@@ -146,8 +157,19 @@ class NES(Attack):
ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit']

Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> SCENE = 'Query_Limit'
>>> TOP_K = -1
>>> attack= NES(model, SCENE, top_k=TOP_K)
>>>
>>> num_class = 5
>>> x_test = np.asarray(np.random.random((32, 32)), np.float32)
>>> target_image = np.asarray(np.random.random((32, 32)), np.float32)
>>> orig_class = 0
>>> target_class = 2
>>> attack.set_target_images(target_image)
>>> tag, adv, queries = attack.generate(np.array(x_test), np.array([target_class]))
"""
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels)
if not self._sparse:
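As background for the NES docstrings above, the attack estimates gradients purely from model queries. A conceptual NumPy sketch of the antithetic-sampling estimator the method is based on is shown below; `loss_fn` is a placeholder for a query against the black-box model, not an API of this class.

    import numpy as np

    def nes_gradient(loss_fn, x, sigma=0.001, n_samples=50):
        # Probe the black-box loss with antithetic Gaussian noise and
        # average the difference-weighted probes (conceptual sketch only).
        rng = np.random.default_rng()
        grad = np.zeros_like(x)
        for _ in range(n_samples):
            u = rng.standard_normal(x.shape)
            grad += (loss_fn(x + sigma * u) - loss_fn(x - sigma * u)) * u
        return grad / (2 * sigma * n_samples)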


+22 -1  mindarmour/adv_robustness/attacks/black/pointwise_attack.py

@@ -47,6 +47,22 @@ class PointWiseAttack(Attack):
Default: True.

Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import PointWiseAttack
>>> from tests.ut.python.utils.mock_net import Net
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PointWiseAttack(model)
"""

@@ -79,7 +95,12 @@ class PointWiseAttack(Attack):
- numpy.ndarray, query times for each sample.

Examples:
>>> is_adv_list, adv_list, query_times_each_adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 3])
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PointWiseAttack(model)
>>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32)
>>> y_test = np.random.randint(0, 3, size=1)
>>> is_adv_list, adv_list, query_times_each_adv = attack.generate(x_test, y_test)
"""
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
if not self._sparse:


+276 -76  mindarmour/adv_robustness/attacks/black/pso_attack.py

@@ -55,7 +55,7 @@ class PSOAttack(Attack):
clip_max). Default: None.
targeted (bool): If True, turns on the targeted attack. If False,
turns on untargeted attack. It should be noted that only untargeted attack
is supproted for model_type='detection', Default: False.
is supported for model_type='detection', Default: False.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now.
@@ -64,7 +64,35 @@ class PSOAttack(Attack):
specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3.

Examples:
>>> attack = PSOAttack(model)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.nn import Cell
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import PSOAttack
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> if len(inputs.shape) == 1:
>>> inputs = np.expand_dims(inputs, axis=0)
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False)
"""

def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True,
@@ -169,18 +197,33 @@ class PSOAttack(Attack):
self._bounds[1])
return mutated_pop

def generate(self, inputs, labels):
def _check_best_fitness(self, best_fitness, iters):
if best_fitness < -2:
LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease '
'the value of the initialization parameter c.', best_fitness)
if iters < 3 and best_fitness > 100:
LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase '
'the value of the initialization parameter c.', best_fitness)

def _update_best_fit_position(self, fit_value, par_best_fit, par_best_poi, par, best_fitness, best_position):
for k in range(self._pop_size):
if fit_value[k] > par_best_fit[k]:
par_best_fit[k] = fit_value[k]
par_best_poi[k] = par[k]
if fit_value[k] > best_fitness:
best_fitness = fit_value[k]
best_position = par[k].copy()
return par_best_fit, par_best_poi, best_fitness, best_position

def _generate_classification(self, inputs, labels):
"""
Generate adversarial examples based on input data and targeted
labels (or ground_truth labels).
Generate adversarial examples based on input data and
targeted labels (or ground_truth labels) for classification model.

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if
model_type='detection'.
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels)
if model_type='detection'.
be numpy.ndarray.

Returns:
- numpy.ndarray, bool values for each attack result.
@@ -190,28 +233,32 @@ class PSOAttack(Attack):
- numpy.ndarray, query times for each sample.

Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False)
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack.generate(x_test, y_test)
"""
# inputs check
if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if self._sparse:
if labels.size > 1:
label_squ = np.squeeze(labels)
else:
label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \
"got its shape as {}.".format(labels.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if self._sparse:
if labels.size > 1:
label_squ = np.squeeze(labels)
else:
labels = np.argmax(labels, axis=1)
images = inputs
elif self._model_type == 'detection':
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)
label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \
"got its shape as {}.".format(labels.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
else:
labels = np.argmax(labels, axis=1)
images = inputs

# generate one adversarial each time
adv_list = []
@@ -226,17 +273,9 @@ class PSOAttack(Attack):
pixel_deep = self._bounds[1] - self._bounds[0]

q_times += 1
if self._model_type == 'classification':
label_i = labels[i]
confi_ori = self._confidence_cla(x_ori, label_i)
elif self._model_type == 'detection':
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0)
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])

label_i = labels[i]
confi_ori = self._confidence_cla(x_ori, label_i)

# step1, initializing
# initial global optimum fitness value, cannot set to be -inf
@@ -277,57 +316,178 @@ class PSOAttack(Attack):
x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds),
self._bounds[0], self._bounds[1])

if self._model_type == 'classification':
confi_adv = self._confidence_cla(par, label_i)
elif self._model_type == 'detection':
confi_adv, _ = self._detection_scores(
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)

confi_adv = self._confidence_cla(par, label_i)

q_times += self._pop_size
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par)
for k in range(self._pop_size):
if fit_value[k] > par_best_fit[k]:
par_best_fit[k] = fit_value[k]
par_best_poi[k] = par[k]
if fit_value[k] > best_fitness:
best_fitness = fit_value[k]
best_position = par[k].copy()
par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value,
par_best_fit,
par_best_poi,
par,
best_fitness,
best_position)
iters += 1
if best_fitness < -2:
LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease '
'the value of the initialization parameter c.', best_fitness)
if iters < 3 and best_fitness > 100:
LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase '
'the value of the initialization parameter c.', best_fitness)
self._check_best_fitness(best_fitness, iters)

is_mutation = False
if (best_fitness - last_best_fit) < last_best_fit*0.05:
is_mutation = True

q_times += 1
if self._model_type == 'classification':
cur_pre = self._model.predict(best_position)
cur_label = np.argmax(cur_pre)
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i):
is_success = True
elif self._model_type == 'detection':
_, correct_nums_adv = self._detection_scores(
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
correct_nums_adv[0])
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio):
is_success = True

cur_pre = self._model.predict(best_position)
cur_label = np.argmax(cur_pre)
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i):
is_success = True

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial '
'sample and start Reduction process')
# step3, reduction
if self._model_type == 'classification':
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model,
targeted_attack=self._targeted)
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model,
targeted_attack=self._targeted)
break

if not is_success:
LOGGER.debug(TAG,
'fail to find adversarial sample, iteration '
'times is: %d and query times is: %d',
iters,
q_times)
adv_list.append(best_position)
success_list.append(is_success)
query_times_list.append(q_times)
del x_copies, cur_noise, par, par_best_poi
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)


def _generate_detection(self, inputs, labels):
"""
Generate adversarial examples based on input data and
targeted labels (or ground_truth labels) for detection model.

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...)
or only one array.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels.
The format of labels should be (gt_boxes, gt_labels).

Returns:
- numpy.ndarray, bool values for each attack result.

- numpy.ndarray, generated adversarial examples.

- numpy.ndarray, query times for each sample.

Examples:
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False)
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack.generate(x_test, y_test)
"""
# inputs check
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)

# generate one adversarial each time
adv_list = []
success_list = []
query_times_list = []
for i in range(images.shape[0]):
is_success = False
q_times = 0
x_ori = images[i]
if not self._bounds:
self._bounds = [np.min(x_ori), np.max(x_ori)]
pixel_deep = self._bounds[1] - self._bounds[0]

q_times += 1
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0)
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])

# step1, initializing
# initial global optimum fitness value, cannot set to be -inf
best_fitness = -np.inf
# initial global optimum position
best_position = x_ori
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep,
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1),
self._per_bounds*(np.abs(x_copies) + 0.1))

# initial advs
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1])
# initial optimum positions for particles
par_best_poi = np.copy(par)
# initial optimum fitness values
par_best_fit = -np.inf*np.ones(self._pop_size)
# step2, optimization
# initial velocities for particles
v_particles = np.zeros(par.shape)
is_mutation = False
iters = 0
while iters < self._t_max:
last_best_fit = best_fitness
ran_1 = np.random.random(par.shape)
ran_2 = np.random.random(par.shape)
v_particles = self._step_size*(
v_particles + self._c1*ran_1*(best_position - par)) \
+ self._c2*ran_2*(par_best_poi - par)

par += v_particles

if iters > 6 and is_mutation:
par = self._mutation_op(par)

par = np.clip(np.clip(par,
x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds,
x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds),
self._bounds[0], self._bounds[1])

confi_adv, _ = self._detection_scores(
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
q_times += self._pop_size
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par)
par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value,
par_best_fit,
par_best_poi,
par,
best_fitness,
best_position)
iters += 1
self._check_best_fitness(best_fitness, iters)

is_mutation = False
if (best_fitness - last_best_fit) < last_best_fit*0.05:
is_mutation = True

q_times += 1

_, correct_nums_adv = self._detection_scores(
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
correct_nums_adv[0])
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio):
is_success = True

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial '
'sample and start Reduction process')
break
if self._model_type == 'detection':
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times,
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times,
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
if not is_success:
LOGGER.debug(TAG,
'fail to find adversarial sample, iteration '
@@ -341,3 +501,43 @@ class PSOAttack(Attack):
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)

def generate(self, inputs, labels):
"""
Generate adversarial examples based on input data and
targeted labels (or ground_truth labels).

Args:
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if
model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if
model_type='detection'.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels)
if model_type='detection'.

Returns:
- numpy.ndarray, bool values for each attack result.

- numpy.ndarray, generated adversarial examples.

- numpy.ndarray, query times for each sample.

Examples:
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False)
>>> batch_size = 6
>>> x_test = np.random.rand(batch_size, 10)
>>> y_test = np.random.randint(low=0, high=10, size=batch_size)
>>> y_test = np.eye(10)[y_test]
>>> y_test = y_test.astype(np.float32)
>>> _, adv_data, _ = attack.generate(x_test, y_test)
"""
# inputs check
if self._model_type == 'classification':
success_list, adv_data, query_time_list = self._generate_classification(inputs, labels)

elif self._model_type == 'detection':
success_list, adv_data, query_time_list = self._generate_detection(inputs, labels)

return success_list, adv_data, query_time_list
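For readers skimming the refactor, the per-particle loop that moved into `_update_best_fit_position` is equivalent to this vectorized NumPy sketch (illustrative only, not part of the patch):

    import numpy as np

    def update_best_fit_position(fit_value, par_best_fit, par_best_poi, par,
                                 best_fitness, best_position):
        # Per-particle personal best.
        improved = fit_value > par_best_fit
        par_best_fit = np.where(improved, fit_value, par_best_fit)
        par_best_poi[improved] = par[improved]
        # Global best across all particles seen so far.
        if fit_value.max() > best_fitness:
            best_fitness = float(fit_value.max())
            best_position = par[np.argmax(fit_value)].copy()
        return par_best_fit, par_best_poi, best_fitness, best_position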

+22 -1  mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py

@@ -40,6 +40,22 @@ class SaltAndPepperNoiseAttack(Attack):
Default: True.

Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindarmour import BlackModel
>>> from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack
>>> from tests.ut.python.utils.mock_net import Net
>>>
>>> class ModelToBeAttacked(BlackModel):
>>> def __init__(self, network):
>>> super(ModelToBeAttacked, self).__init__()
>>> self._network = network
>>> def predict(self, inputs):
>>> result = self._network(Tensor(inputs.astype(np.float32)))
>>> return result.asnumpy()
>>>
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = SaltAndPepperNoiseAttack(model)
"""

@@ -69,7 +85,12 @@ class SaltAndPepperNoiseAttack(Attack):
- numpy.ndarray, query times for each sample.

Examples:
>>> adv_list = attack.generate(([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2])
>>> net = Net()
>>> model = ModelToBeAttacked(net)
>>> attack = SaltAndPepperNoiseAttack(model)
>>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32)
>>> y_test = np.random.randint(0, 3, size=1)
>>> _, adv_list, _ = attack.generate(x_test, y_test)
"""
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
if not self._sparse:
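For context on what the new example exercises, the perturbation itself is plain salt-and-pepper noise; a conceptual sketch of a single noise step (not the class's actual implementation, which grows the noise ratio until the wrapped model's prediction flips):

    import numpy as np

    def salt_and_pepper(x, ratio, bounds=(0.0, 1.0)):
        # Push a random fraction `ratio` of entries to the lower or upper bound.
        rng = np.random.default_rng()
        low, high = bounds
        noisy = np.array(x, copy=True)
        mask = rng.random(noisy.shape) < ratio
        salt = rng.random(noisy.shape) < 0.5
        noisy[mask & salt] = high
        noisy[mask & ~salt] = low
        return noisy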


+51 -7  mindarmour/adv_robustness/attacks/carlini_wagner.py

@@ -95,7 +95,24 @@ class CarliniWagnerL2Attack(Attack):
input labels are onehot-coded. Default: True.

Examples:
>>> attack = CarliniWagnerL2Attack(network)
>>> import numpy as np
>>> import mindspore.ops.operations as M
>>> from mindspore.nn import Cell
>>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._softmax = M.Softmax()
>>>
>>> def construct(self, inputs):
>>> out = self._softmax(inputs)
>>> return out
>>>
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
>>> label_np = np.array([3]).astype(np.int64)
>>> num_classes = input_np.shape[1]
>>> net = Net()
>>> attack = CarliniWagnerL2Attack(net, num_classes, targeted=False)
"""

def __init__(self, network, num_classes, box_min=0.0, box_max=1.0,
@@ -246,6 +263,14 @@ class CarliniWagnerL2Attack(Attack):
the_grad = the_grad*diff
return inputs, the_grad

def _check_success(self, logits, labels):
""" check if attack success (include all examples)"""
if self._targeted:
is_adv = (np.argmax(logits, axis=1) == labels)
else:
is_adv = (np.argmax(logits, axis=1) != labels)
return is_adv

def generate(self, inputs, labels):
"""
Generate adversarial examples based on input data and targeted labels.
@@ -259,7 +284,30 @@ class CarliniWagnerL2Attack(Attack):
numpy.ndarray, generated adversarial examples.

Examples:
>>> advs = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2]]
>>> import numpy as np
>>> import mindspore.ops.operations as M
>>> from mindspore.nn import Cell
>>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._softmax = M.Softmax()
>>>
>>> def construct(self, inputs):
>>> out = self._softmax(inputs)
>>> return out
>>>
>>> net = Net()
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
>>> num_classes = input_np.shape[1]
>>>
>>> label_np = np.array([3]).astype(np.int64)
>>> attack_nonTargeted = CarliniWagnerL2Attack(net, num_classes, targeted=False)
>>> advs_nonTargeted = attack_nonTargeted.generate(input_np, label_np)
>>>
>>> target_np = np.array([1]).astype(np.int64)
>>> attack_targeted = CarliniWagnerL2Attack(net, num_classes, targeted=True)
>>> advs_targeted = attack_targeted.generate(input_np, target_np)
"""

LOGGER.debug(TAG, "enter the func generate.")
@@ -302,11 +350,7 @@ class CarliniWagnerL2Attack(Attack):
logits, x_input, reconstructed_original,
labels, const, self._confidence)

# check if attack success (include all examples)
if self._targeted:
is_adv = (np.argmax(logits, axis=1) == labels)
else:
is_adv = (np.argmax(logits, axis=1) != labels)
is_adv = self._check_success(logits, labels)

for i in range(samples_num):
if is_adv[i]:
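The extracted `_check_success` helper only compares argmax predictions against the labels; a standalone sketch of that logic, for readers who want to verify it outside the class:

    import numpy as np

    def check_success(logits, labels, targeted):
        # Targeted: success when the model predicts the target label.
        # Untargeted: success when the prediction differs from the true label.
        preds = np.argmax(logits, axis=1)
        return preds == labels if targeted else preds != labels

    logits = np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
    print(check_success(logits, np.array([1, 0]), targeted=True))   # [ True  True]
    print(check_success(logits, np.array([1, 0]), targeted=False))  # [False False]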


+35 -10  mindarmour/adv_robustness/attacks/deep_fool.py

@@ -117,7 +117,23 @@ class DeepFool(Attack):
input labels are onehot-coded. Default: True.

Examples:
>>> attack = DeepFool(network)
>>> import numpy as np
>>> import mindspore.ops.operations as P
>>> from mindspore.nn import Cell
>>> from mindspore import Tensor
>>> from mindarmour.adv_robustness.attacks import DeepFool
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._softmax = P.Softmax()
>>>
>>> def construct(self, inputs):
>>> out = self._softmax(inputs)
>>> return out
>>>
>>> net = Net()
>>> classes = 10
>>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, bounds=(0.0, 1.0))
"""

def __init__(self, network, num_classes, model_type='classification',
@@ -165,14 +181,30 @@ class DeepFool(Attack):
NotImplementedError: If norm_level is not in [2, np.inf, '2', 'inf'].

Examples:
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
>>> input_shape = (1, 5)
>>> _, classes = input_shape
>>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
>>> input_me = Tensor(input_np)
>>> true_labels = np.argmax(net(input_me).asnumpy(), axis=1)
>>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, bounds=(0.0, 1.0))
>>> advs = attack.generate(input_np, true_labels)
"""

if self._model_type == 'detection':
return self._generate_detection(inputs, labels)
if self._model_type == 'classification':
return self._generate_classification(inputs, labels)
return None

def _update_image(self, x_origin, r_tot):
"""update image based on bounds"""
if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
else:
images = x_origin + (1 + self._overshoot)*r_tot
return images

def _generate_detection(self, inputs, labels):
"""Generate adversarial examples in detection scenario"""
@@ -239,19 +271,12 @@ class DeepFool(Attack):
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
else:
images = x_origin + (1 + self._overshoot)*r_tot
images = self._update_image(x_origin, r_tot)
iteration += 1
images = images.astype(images_dtype)
del preds_logits, grads
return images



def _generate_classification(self, inputs, labels):
"""Generate adversarial examples in classification scenario"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
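The new `_update_image` helper just factors out the overshoot-and-clip step of DeepFool. A standalone sketch of the same logic (the 0.02 overshoot default below is only an illustrative value, not the class's default):

    import numpy as np

    def update_image(x_origin, r_tot, overshoot=0.02, bounds=(0.0, 1.0)):
        # Scale the accumulated perturbation by (1 + overshoot); when bounds are
        # given, also scale by the value range and clip back into it.
        if bounds is not None:
            clip_min, clip_max = bounds
            return np.clip(x_origin + (1 + overshoot) * r_tot * (clip_max - clip_min),
                           clip_min, clip_max)
        return x_origin + (1 + overshoot) * r_tot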


+111 -5  mindarmour/adv_robustness/attacks/gradient_method.py

@@ -47,9 +47,25 @@ class GradientMethod(Attack):
is already equipped with loss function. Default: None.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindspore import Tensor
>>> from mindarmour.adv_robustness.attacks import FastGradientMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> net = Net()
>>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -155,9 +171,24 @@ class FastGradientMethod(GradientMethod):
is already equipped with loss function. Default: None.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import FastGradientMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> net = Net()
>>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -223,9 +254,24 @@ class RandomFastGradientMethod(FastGradientMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = RandomFastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -265,9 +311,24 @@ class FastGradientSignMethod(GradientMethod):
is already equipped with loss function. Default: None.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = FastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -329,9 +390,24 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomFastGradientSignMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> attack = RandomFastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate(inputs, labels)
"""

@@ -366,6 +442,21 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
is already equipped with loss function. Default: None.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import LeastLikelyClassMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
@@ -404,6 +495,21 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
ValueError: eps is smaller than alpha!

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import RandomLeastLikelyClassMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
>>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
>>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
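All six classes in this file share the one-step construction shown in the examples above. For orientation, the sign-based variants conceptually apply the following step (a NumPy sketch only; the real classes obtain the gradient from the wrapped MindSpore network and `loss_fn`, and `FastGradientMethod` normalizes the gradient instead of taking its sign):

    import numpy as np

    def fgsm_step(x, grad, eps=0.3, bounds=(0.0, 1.0)):
        # One fast-gradient-sign step: move eps along the sign of the loss
        # gradient, then clip back into the valid data range.
        return np.clip(x + eps * np.sign(grad), *bounds)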


+59 -1  mindarmour/adv_robustness/attacks/iterative_gradient_method.py

@@ -184,7 +184,22 @@ class BasicIterativeMethod(IterativeGradientMethod):
is already equipped with loss function. Default: None.

Examples:
>>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import BasicIterativeMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
"""
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
is_targeted=False, nb_iter=5, loss_fn=None):
@@ -215,6 +230,17 @@ class BasicIterativeMethod(IterativeGradientMethod):
numpy.ndarray, generated adversarial examples.

Examples:
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate([[0.3, 0.2, 0.6],
>>> [0.3, 0.2, 0.4]],
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
@@ -303,6 +329,22 @@ class MomentumIterativeMethod(IterativeGradientMethod):
numpy.ndarray, generated adversarial examples.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> attack = MomentumIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate([[0.5, 0.2, 0.6],
>>> [0.3, 0, 0.2]],
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
@@ -433,6 +475,22 @@ class ProjectedGradientDescent(BasicIterativeMethod):
numpy.ndarray, generated adversarial examples.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
>>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
>>>
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> attack = ProjectedGradientDescent(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
>>> adv_x = attack.generate([[0.6, 0.2, 0.6],
>>> [0.3, 0.3, 0.4]],
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],


+24 -2  mindarmour/adv_robustness/attacks/jsma.py

@@ -54,7 +54,23 @@ class JSMAAttack(Attack):
input labels are onehot-coded. Default: True.

Examples:
>>> attack = JSMAAttack(network)
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.nn import Cell
>>> from mindarmour.adv_robustness.attacks import JSMAAttack
>>> class Net(Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self._relu = nn.ReLU()
>>>
>>> def construct(self, inputs):
>>> out = self._relu(inputs)
>>> return out
>>>
>>> net = Net()
>>> input_shape = (1, 5)
>>> batch_size, classes = input_shape
>>> attack = JSMAAttack(net, classes, max_iteration=5)
"""

def __init__(self, network, num_classes, box_min=0.0, box_max=1.0,
@@ -181,7 +197,13 @@ class JSMAAttack(Attack):
numpy.ndarray, adversarial samples.

Examples:
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
>>> input_shape = (1, 5)
>>> batch_size, classes = input_shape
>>> input_np = np.random.random(input_shape).astype(np.float32)
>>> label_np = np.random.randint(classes, size=batch_size)
>>>
>>> attack = JSMAAttack(net, classes, max_iteration=5)
>>> advs = attack.generate(input_np, label_np)
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)


+8 -2  mindarmour/adv_robustness/attacks/lbfgs.py

@@ -54,7 +54,12 @@ class LBFGS(Attack):
input labels are onehot-coded. Default: False.

Examples:
>>> attack = LBFGS(network)
>>> import numpy as np
>>> from mindarmour.adv_robustness.attacks import LBFGS
>>> from tests.ut.python.utils.mock_net import Net
>>>
>>> net = Net()
>>> attack = LBFGS(net, is_targeted=True)
"""
def __init__(self, network, eps=1e-5, bounds=(0.0, 1.0), is_targeted=True,
nb_iter=150, search_iters=30, loss_fn=None, sparse=False):
@@ -94,6 +99,7 @@ class LBFGS(Attack):
numpy.ndarray, generated adversarial examples.

Examples:
>>> attack = LBFGS(net, is_targeted=True)
>>> adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 2])
"""
LOGGER.debug(TAG, 'start to generate adv image.')
@@ -191,7 +197,7 @@ class LBFGS(Attack):

def _optimize(self, start_input, labels, epsilon):
"""
Given loss fuction and gradient, use l_bfgs_b algorithm to update input
Given loss function and gradient, use l_bfgs_b algorithm to update input
sample. The epsilon will be doubled until an adversarial example is found.

Args:


+2 -2  tests/ut/python/adv_robustness/attacks/black/test_nes.py

@@ -28,7 +28,7 @@ from tests.ut.python.utils.mock_net import Net
context.set_context(mode=context.GRAPH_MODE)

LOGGER = LogUtil.get_instance()
TAG = 'HopSkipJumpAttack'
TAG = 'NaturalEvolutionaryStrategy'


class ModelToBeAttacked(BlackModel):
@@ -100,7 +100,7 @@ def get_dataset(current_dir):

def nes_mnist_attack(scene, top_k):
"""
hsja-Attack test
nes-Attack test
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
test_images, test_labels = get_dataset(current_dir)

