diff --git a/mindarmour/adv_robustness/attacks/black/genetic_attack.py b/mindarmour/adv_robustness/attacks/black/genetic_attack.py index 991a650..4263f0b 100644 --- a/mindarmour/adv_robustness/attacks/black/genetic_attack.py +++ b/mindarmour/adv_robustness/attacks/black/genetic_attack.py @@ -45,7 +45,7 @@ class GeneticAttack(Attack): default: 'classification'. targeted (bool): If True, turns on the targeted attack. If False, turns on untargeted attack. It should be noted that only untargeted attack - is supproted for model_type='detection', Default: True. + is supported for model_type='detection', Default: True. reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. pop_size (int): The number of particles, which should be greater than @@ -69,7 +69,33 @@ class GeneticAttack(Attack): c (Union[int, float]): Weight of perturbation loss. Default: 0.1. Examples: - >>> attack = GeneticAttack(model) + >>> import numpy as np + >>> import mindspore.ops.operations as M + >>> from mindspore import Tensor + >>> from mindspore.nn import Cell + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import GeneticAttack + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._softmax = M.Softmax() + >>> + >>> def construct(self, inputs): + >>> out = self._softmax(inputs) + >>> return out + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = GeneticAttack(model, sparse=False) """ def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True, pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3, @@ -135,18 +161,77 @@ class GeneticAttack(Attack): np.random.random(cur_pop.shape) < prob) + cur_pop return mutated_pop - def generate(self, inputs, labels): + + def _compute_next_generation(self, cur_pop, fit_vals, x_ori): """ - Generate adversarial examples based on input data and targeted - labels (or ground_truth labels). + Compute pop for next generation Args: - inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if - model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if - model_type='detection'. - labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should - be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) - if model_type='detection'. + cur_pop (numpy.ndarray): Samples before mutation. 
+ fit_vals (numpy.ndarray): fitness values + x_ori (numpy.ndarray): original input x + + Returns: + numpy.ndarray, pop after generation + + Examples: + >>> cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) + """ + best_fit = max(fit_vals) + + if best_fit > self._best_fit: + self._best_fit = best_fit + self._plateau_times = 0 + else: + self._plateau_times += 1 + adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit) + if self._plateau_times > adap_threshold: + self._adap_times += 1 + self._plateau_times = 0 + if self._adaptive: + step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) + step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) + else: + step_noise = self._step_size + step_p = self._mutation_rate + step_temp = self._temp + elite = cur_pop[np.argmax(fit_vals)] + select_probs = softmax(fit_vals/step_temp) + select_args = np.arange(self._pop_size) + parents_arg = np.random.choice( + a=select_args, size=2*(self._pop_size - 1), + replace=True, p=select_probs) + parent1 = cur_pop[parents_arg[:self._pop_size - 1]] + parent2 = cur_pop[parents_arg[self._pop_size - 1:]] + parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] + parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] + parent2_probs = parent2_probs / (parent1_probs + parent2_probs) + # duplicate the probabilities to all features of each particle. + dims = len(x_ori.shape) + for _ in range(dims): + parent2_probs = parent2_probs[:, np.newaxis] + parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape)) + cross_probs = (np.random.random(parent1.shape) > + parent2_probs).astype(np.int32) + children = parent1*cross_probs + parent2*(1 - cross_probs) + mutated_children = self._mutation( + children, step_noise=self._per_bounds*step_noise, + prob=step_p) + cur_pop = np.concatenate((mutated_children, elite[np.newaxis, :])) + + return cur_pop, elite + + + + def _generate_classification(self, inputs, labels): + """ + Generate adversarial examples based on input data and + targeted labels (or ground_truth labels) for classification model. + + Args: + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. + The format of labels should be numpy.ndarray. Returns: - numpy.ndarray, bool values for each attack result. @@ -156,28 +241,27 @@ class GeneticAttack(Attack): - numpy.ndarray, query times for each sample. 
Examples: - >>> advs = attack.generate([[0.2, 0.3, 0.4], - >>> [0.3, 0.3, 0.2]], - >>> [1, 2]) + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack._generate_classification(x_test, y_test) """ - if self._model_type == 'classification': - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - if self._sparse: - if labels.size > 1: - label_squ = np.squeeze(labels) - else: - label_squ = labels - if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: - msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ - "and got its shape as {}.".format(labels.shape) - LOGGER.error(TAG, msg) - raise ValueError(msg) + inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) + if self._sparse: + if labels.size > 1: + label_squ = np.squeeze(labels) else: - labels = np.argmax(labels, axis=1) - images = inputs - elif self._model_type == 'detection': - images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + label_squ = labels + if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: + msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ + "and got its shape as {}.".format(labels.shape) + LOGGER.error(TAG, msg) + raise ValueError(msg) + else: + labels = np.argmax(labels, axis=1) + images = inputs adv_list = [] success_list = [] @@ -188,17 +272,7 @@ class GeneticAttack(Attack): if not self._bounds: self._bounds = [np.min(x_ori), np.max(x_ori)] pixel_deep = self._bounds[1] - self._bounds[0] - - if self._model_type == 'classification': - label_i = labels[i] - elif self._model_type == 'detection': - auxiliary_input_i = tuple() - for item in auxiliary_inputs: - auxiliary_input_i += (np.expand_dims(item[i], axis=0),) - gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) - inputs_i = (images[i],) + auxiliary_input_i - confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) - LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + label_i = labels[i] # generate particles ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) @@ -215,106 +289,148 @@ class GeneticAttack(Attack): ori_copies + pixel_deep*self._per_bounds), self._bounds[0], self._bounds[1]) - if self._model_type == 'classification': - pop_preds = self._model.predict(cur_pop) - query_times += cur_pop.shape[0] - all_preds = np.argmax(pop_preds, axis=1) - if self._targeted: - success_pop = np.equal(label_i, all_preds).astype(np.int32) - else: - success_pop = np.not_equal(label_i, all_preds).astype(np.int32) - is_success = max(success_pop) - best_idx = np.argmax(success_pop) - target_preds = pop_preds[:, label_i] - others_preds_sum = np.sum(pop_preds, axis=1) - target_preds - if self._targeted: - fit_vals = target_preds - others_preds_sum - else: - fit_vals = others_preds_sum - target_preds - - elif self._model_type == 'detection': - confi_adv, correct_nums_adv = self._detection_scores( - (cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) - LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', - np.min(correct_nums_adv)) - query_times += self._pop_size - fit_vals = abs( - confi_ori - 
confi_adv) - self._c / self._pop_size * np.linalg.norm( - (cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) - - if np.max(fit_vals) < 0: - self._c /= 2 - - if np.max(fit_vals) < -2: - LOGGER.debug(TAG, - 'best fitness value is %s, which is too small. We recommend that you decrease ' - 'the value of the initialization parameter c.', np.max(fit_vals)) - if iters < 3 and np.max(fit_vals) > 100: - LOGGER.debug(TAG, - 'best fitness value is %s, which is too large. We recommend that you increase ' - 'the value of the initialization parameter c.', np.max(fit_vals)) - - if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): - is_success = True - best_idx = np.argmin(correct_nums_adv) + pop_preds = self._model.predict(cur_pop) + query_times += cur_pop.shape[0] + all_preds = np.argmax(pop_preds, axis=1) + if self._targeted: + success_pop = np.equal(label_i, all_preds).astype(np.int32) + else: + success_pop = np.not_equal(label_i, all_preds).astype(np.int32) + is_success = max(success_pop) + best_idx = np.argmax(success_pop) + target_preds = pop_preds[:, label_i] + others_preds_sum = np.sum(pop_preds, axis=1) - target_preds + if self._targeted: + fit_vals = target_preds - others_preds_sum + else: + fit_vals = others_preds_sum - target_preds if is_success: LOGGER.debug(TAG, 'successfully find one adversarial sample ' 'and start Reduction process.') final_adv = cur_pop[best_idx] - if self._model_type == 'classification': - final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, - model=self._model, targeted_attack=self._targeted) + + final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, + model=self._model, targeted_attack=self._targeted) break - best_fit = max(fit_vals) + cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) - if best_fit > self._best_fit: - self._best_fit = best_fit - self._plateau_times = 0 - else: - self._plateau_times += 1 - adap_threshold = (lambda z: 100 if z > -0.4 else 300)(best_fit) - if self._plateau_times > adap_threshold: - self._adap_times += 1 - self._plateau_times = 0 - if self._adaptive: - step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) - step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) - else: - step_noise = self._step_size - step_p = self._mutation_rate - step_temp = self._temp - elite = cur_pop[np.argmax(fit_vals)] - select_probs = softmax(fit_vals/step_temp) - select_args = np.arange(self._pop_size) - parents_arg = np.random.choice( - a=select_args, size=2*(self._pop_size - 1), - replace=True, p=select_probs) - parent1 = cur_pop[parents_arg[:self._pop_size - 1]] - parent2 = cur_pop[parents_arg[self._pop_size - 1:]] - parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] - parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] - parent2_probs = parent2_probs / (parent1_probs + parent2_probs) - # duplicate the probabilities to all features of each particle. 
- dims = len(x_ori.shape) - for _ in range(dims): - parent2_probs = parent2_probs[:, np.newaxis] - parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape)) - cross_probs = (np.random.random(parent1.shape) > - parent2_probs).astype(np.int32) - childs = parent1*cross_probs + parent2*(1 - cross_probs) - mutated_childs = self._mutation( - childs, step_noise=self._per_bounds*step_noise, - prob=step_p) - cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :])) + if not is_success: + LOGGER.debug(TAG, 'fail to find adversarial sample.') + final_adv = elite + adv_list.append(final_adv) + + LOGGER.debug(TAG, + 'iteration times is: %d and query times is: %d', + iters, + query_times) + success_list.append(is_success) + query_times_list.append(query_times) + del ori_copies, cur_pert, cur_pop + return np.asarray(success_list), \ + np.asarray(adv_list), \ + np.asarray(query_times_list) + + + + def _generate_detection(self, inputs, labels): + """ + Generate adversarial examples based on input data and + targeted labels (or ground_truth labels) for detection model. + + Args: + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be only one array. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should + be (gt_boxes, gt_labels). + + Returns: + - numpy.ndarray, bool values for each attack result. + + - numpy.ndarray, generated adversarial examples. + + - numpy.ndarray, query times for each sample. + + Examples: + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack._generate_detection(x_test, y_test) + """ + images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + adv_list = [] + success_list = [] + query_times_list = [] + for i in range(images.shape[0]): + is_success = False + x_ori = images[i] + if not self._bounds: + self._bounds = [np.min(x_ori), np.max(x_ori)] + pixel_deep = self._bounds[1] - self._bounds[0] + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[i], axis=0),) + gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) + inputs_i = (images[i],) + auxiliary_input_i + confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) + LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + + # generate particles + ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) + # initial perturbations + cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape) + cur_pop = ori_copies + cur_pert + query_times = 0 + iters = 0 + + while iters < self._max_steps: + iters += 1 + cur_pop = np.clip(np.clip(cur_pop, + ori_copies - pixel_deep*self._per_bounds, + ori_copies + pixel_deep*self._per_bounds), + self._bounds[0], self._bounds[1]) + + confi_adv, correct_nums_adv = self._detection_scores( + (cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) + LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', + np.min(correct_nums_adv)) + query_times += self._pop_size + fit_vals = abs( + confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( + (cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) + + if np.max(fit_vals) < 0: + self._c /= 2 + + if 
np.max(fit_vals) < -2: + LOGGER.debug(TAG, + 'best fitness value is %s, which is too small. We recommend that you decrease ' + 'the value of the initialization parameter c.', np.max(fit_vals)) + if iters < 3 and np.max(fit_vals) > 100: + LOGGER.debug(TAG, + 'best fitness value is %s, which is too large. We recommend that you increase ' + 'the value of the initialization parameter c.', np.max(fit_vals)) + + if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): + is_success = True + best_idx = np.argmin(correct_nums_adv) + + if is_success: + LOGGER.debug(TAG, 'successfully find one adversarial sample ' + 'and start Reduction process.') + final_adv = cur_pop[best_idx] + break + + cur_pop, elite = self._compute_next_generation(cur_pop, fit_vals, x_ori) if not is_success: LOGGER.debug(TAG, 'fail to find adversarial sample.') final_adv = elite - if self._model_type == 'detection': - final_adv, query_times = self._fast_reduction( - x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) + + final_adv, query_times = self._fast_reduction( + x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) adv_list.append(final_adv) LOGGER.debug(TAG, @@ -327,3 +443,38 @@ class GeneticAttack(Attack): return np.asarray(success_list), \ np.asarray(adv_list), \ np.asarray(query_times_list) + + def generate(self, inputs, labels): + """ + Generate adversarial examples based on input data and targeted labels (or ground_truth labels). + + Args: + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if + model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if + model_type='detection'. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should + be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) + if model_type='detection'. + + Returns: + - numpy.ndarray, bool values for each attack result. + + - numpy.ndarray, generated adversarial examples. + + - numpy.ndarray, query times for each sample. 
+ + Examples: + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack.generate(x_test, y_test) + """ + if self._model_type == 'classification': + success_list, adv_data, query_time_list = self._generate_classification(inputs, labels) + + elif self._model_type == 'detection': + success_list, adv_data, query_time_list = self._generate_detection(inputs, labels) + + return success_list, adv_data, query_time_list diff --git a/mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py b/mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py index 2d6e0b1..98b8c34 100644 --- a/mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py +++ b/mindarmour/adv_robustness/attacks/black/hop_skip_jump_attack.py @@ -75,11 +75,26 @@ class HopSkipJumpAttack(Attack): ValueError: If constraint not in ['l2', 'linf'] Examples: - >>> x_test = np.asarray(np.random.random((sample_num, - >>> sample_length)), np.float32) - >>> y_test = np.random.randint(0, class_num, size=sample_num) - >>> instance = HopSkipJumpAttack(user_model) - >>> adv_x = instance.generate(x_test, y_test) + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import HopSkipJumpAttack + >>> from tests.ut.python.utils.mock_net import Net + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> if len(inputs.shape) == 3: + >>> inputs = inputs[np.newaxis, :] + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = HopSkipJumpAttack(model) """ def __init__(self, model, init_num_evals=100, max_num_evals=1000, @@ -173,7 +188,13 @@ class HopSkipJumpAttack(Attack): - numpy.ndarray, query times for each sample. Examples: - >>> generate([[0.1,0.2,0.2],[0.2,0.3,0.4]],[2,6]) + >>> attack = HopSkipJumpAttack(model) + >>> n, c, h, w = 1, 1, 32, 32 + >>> class_num = 3 + >>> x_test = np.asarray(np.random.random((n,c,h,w)), np.float32) + >>> y_test = np.random.randint(0, class_num, size=n) + >>> + >>> _, adv_x, _= attack.generate(x_test, y_test) """ if labels is not None: inputs, labels = check_pair_numpy_param('inputs', inputs, diff --git a/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py b/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py index 85da72a..60ae91e 100644 --- a/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py +++ b/mindarmour/adv_robustness/attacks/black/natural_evolutionary_strategy.py @@ -79,16 +79,27 @@ class NES(Attack): input labels are one-hot-encoded. Default: True. 
Examples: - >>> SCENE = 'Label_Only' - >>> TOP_K = 5 - >>> num_class = 5 - >>> nes_instance = NES(user_model, SCENE, top_k=TOP_K) - >>> initial_img = np.asarray(np.random.random((32, 32)), np.float32) - >>> target_image = np.asarray(np.random.random((32, 32)), np.float32) - >>> orig_class = 0 - >>> target_class = 2 - >>> nes_instance.set_target_images(target_image) - >>> tag, adv, queries = nes_instance.generate([initial_img], [target_class]) + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import NES + >>> from tests.ut.python.utils.mock_net import Net + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> if len(inputs.shape) == 3: + >>> inputs = inputs[np.newaxis, :] + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> SCENE = 'Query_Limit' + >>> TOP_K = -1 + >>> attack= NES(model, SCENE, top_k=TOP_K) """ def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10, batch_size=128, epsilon=0.3, @@ -146,8 +157,19 @@ class NES(Attack): ValueError: If scene is not in ['Label_Only', 'Partial_Info', 'Query_Limit'] Examples: - >>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], - >>> [1, 2]) + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> SCENE = 'Query_Limit' + >>> TOP_K = -1 + >>> attack= NES(model, SCENE, top_k=TOP_K) + >>> + >>> num_class = 5 + >>> x_test = np.asarray(np.random.random((32, 32)), np.float32) + >>> target_image = np.asarray(np.random.random((32, 32)), np.float32) + >>> orig_class = 0 + >>> target_class = 2 + >>> attack.set_target_images(target_image) + >>> tag, adv, queries = attack.generate(np.array(x_test), np.array([target_class])) """ inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) if not self._sparse: diff --git a/mindarmour/adv_robustness/attacks/black/pointwise_attack.py b/mindarmour/adv_robustness/attacks/black/pointwise_attack.py index c62f0c0..c4d824a 100644 --- a/mindarmour/adv_robustness/attacks/black/pointwise_attack.py +++ b/mindarmour/adv_robustness/attacks/black/pointwise_attack.py @@ -47,6 +47,22 @@ class PointWiseAttack(Attack): Default: True. Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import PointWiseAttack + >>> from tests.ut.python.utils.mock_net import Net + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) >>> attack = PointWiseAttack(model) """ @@ -79,7 +95,12 @@ class PointWiseAttack(Attack): - numpy.ndarray, query times for each sample. 
Examples: - >>> is_adv_list, adv_list, query_times_each_adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 3]) + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = PointWiseAttack(model) + >>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32) + >>> y_test = np.random.randint(0, 3, size=1) + >>> is_adv_list, adv_list, query_times_each_adv = attack.generate(x_test, y_test) """ arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) if not self._sparse: diff --git a/mindarmour/adv_robustness/attacks/black/pso_attack.py b/mindarmour/adv_robustness/attacks/black/pso_attack.py index d960693..0ec5ad5 100644 --- a/mindarmour/adv_robustness/attacks/black/pso_attack.py +++ b/mindarmour/adv_robustness/attacks/black/pso_attack.py @@ -55,7 +55,7 @@ class PSOAttack(Attack): clip_max). Default: None. targeted (bool): If True, turns on the targeted attack. If False, turns on untargeted attack. It should be noted that only untargeted attack - is supproted for model_type='detection', Default: False. + is supported for model_type='detection', Default: False. sparse (bool): If True, input labels are sparse-encoded. If False, input labels are one-hot-encoded. Default: True. model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. @@ -64,7 +64,35 @@ class PSOAttack(Attack): specifically for model_type='detection'. Reserve_ratio should be in the range of (0, 1). Default: 0.3. Examples: - >>> attack = PSOAttack(model) + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore import Tensor + >>> from mindspore.nn import Cell + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import PSOAttack + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> if len(inputs.shape) == 1: + >>> inputs = np.expand_dims(inputs, axis=0) + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> + >>> def construct(self, inputs): + >>> out = self._relu(inputs) + >>> return out + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) """ def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True, @@ -169,18 +197,33 @@ class PSOAttack(Attack): self._bounds[1]) return mutated_pop - def generate(self, inputs, labels): + def _check_best_fitness(self, best_fitness, iters): + if best_fitness < -2: + LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' + 'the value of the initialization parameter c.', best_fitness) + if iters < 3 and best_fitness > 100: + LOGGER.debug(TAG, 'best fitness value is %s, which is too large. 
We recommend that you increase ' + 'the value of the initialization parameter c.', best_fitness) + + def _update_best_fit_position(self, fit_value, par_best_fit, par_best_poi, par, best_fitness, best_position): + for k in range(self._pop_size): + if fit_value[k] > par_best_fit[k]: + par_best_fit[k] = fit_value[k] + par_best_poi[k] = par[k] + if fit_value[k] > best_fitness: + best_fitness = fit_value[k] + best_position = par[k].copy() + return par_best_fit, par_best_poi, best_fitness, best_position + + def _generate_classification(self, inputs, labels): """ - Generate adversarial examples based on input data and targeted - labels (or ground_truth labels). + Generate adversarial examples based on input data and + targeted labels (or ground_truth labels) for classification model. Args: - inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if - model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if - model_type='detection'. + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray. labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should - be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) - if model_type='detection'. + be numpy.ndarray. Returns: - numpy.ndarray, bool values for each attack result. @@ -190,28 +233,32 @@ class PSOAttack(Attack): - numpy.ndarray, query times for each sample. Examples: - >>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], - >>> [1, 2]) + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack.generate(x_test, y_test) """ # inputs check - if self._model_type == 'classification': - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - if self._sparse: - if labels.size > 1: - label_squ = np.squeeze(labels) - else: - label_squ = labels - if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: - msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ - "got its shape as {}.".format(labels.shape) - LOGGER.error(TAG, msg) - raise ValueError(msg) + inputs, labels = check_pair_numpy_param('inputs', inputs, + 'labels', labels) + if self._sparse: + if labels.size > 1: + label_squ = np.squeeze(labels) else: - labels = np.argmax(labels, axis=1) - images = inputs - elif self._model_type == 'detection': - images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + label_squ = labels + if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: + msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ + "got its shape as {}.".format(labels.shape) + LOGGER.error(TAG, msg) + raise ValueError(msg) + else: + labels = np.argmax(labels, axis=1) + images = inputs # generate one adversarial each time adv_list = [] @@ -226,17 +273,9 @@ class PSOAttack(Attack): pixel_deep = self._bounds[1] - self._bounds[0] q_times += 1 - if self._model_type == 'classification': - label_i = labels[i] - confi_ori = self._confidence_cla(x_ori, label_i) - elif self._model_type == 'detection': - 
auxiliary_input_i = tuple() - for item in auxiliary_inputs: - auxiliary_input_i += (np.expand_dims(item[i], axis=0),) - gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) - inputs_i = (images[i],) + auxiliary_input_i - confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) - LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + + label_i = labels[i] + confi_ori = self._confidence_cla(x_ori, label_i) # step1, initializing # initial global optimum fitness value, cannot set to be -inf @@ -277,57 +316,178 @@ class PSOAttack(Attack): x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), self._bounds[0], self._bounds[1]) - if self._model_type == 'classification': - confi_adv = self._confidence_cla(par, label_i) - elif self._model_type == 'detection': - confi_adv, _ = self._detection_scores( - (par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) + + confi_adv = self._confidence_cla(par, label_i) + q_times += self._pop_size fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) - for k in range(self._pop_size): - if fit_value[k] > par_best_fit[k]: - par_best_fit[k] = fit_value[k] - par_best_poi[k] = par[k] - if fit_value[k] > best_fitness: - best_fitness = fit_value[k] - best_position = par[k].copy() + par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value, + par_best_fit, + par_best_poi, + par, + best_fitness, + best_position) iters += 1 - if best_fitness < -2: - LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' - 'the value of the initialization parameter c.', best_fitness) - if iters < 3 and best_fitness > 100: - LOGGER.debug(TAG, 'best fitness value is %s, which is too large. 
We recommend that you increase ' - 'the value of the initialization parameter c.', best_fitness) + self._check_best_fitness(best_fitness, iters) + is_mutation = False if (best_fitness - last_best_fit) < last_best_fit*0.05: is_mutation = True q_times += 1 - if self._model_type == 'classification': - cur_pre = self._model.predict(best_position) - cur_label = np.argmax(cur_pre) - if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): - is_success = True - elif self._model_type == 'detection': - _, correct_nums_adv = self._detection_scores( - (best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) - LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', - correct_nums_adv[0]) - if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): - is_success = True + + cur_pre = self._model.predict(best_position) + cur_label = np.argmax(cur_pre) + if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): + is_success = True if is_success: LOGGER.debug(TAG, 'successfully find one adversarial ' 'sample and start Reduction process') # step3, reduction - if self._model_type == 'classification': - best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, - targeted_attack=self._targeted) + best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, + targeted_attack=self._targeted) + break + + if not is_success: + LOGGER.debug(TAG, + 'fail to find adversarial sample, iteration ' + 'times is: %d and query times is: %d', + iters, + q_times) + adv_list.append(best_position) + success_list.append(is_success) + query_times_list.append(q_times) + del x_copies, cur_noise, par, par_best_poi + return np.asarray(success_list), \ + np.asarray(adv_list), \ + np.asarray(query_times_list) + + + def _generate_detection(self, inputs, labels): + """ + Generate adversarial examples based on input data and + targeted labels (or ground_truth labels) for detection model. + + Args: + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) + or only one array. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. + The format of labels should be (gt_boxes, gt_labels). + + Returns: + - numpy.ndarray, bool values for each attack result. + + - numpy.ndarray, generated adversarial examples. + + - numpy.ndarray, query times for each sample. 
+ + Examples: + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack.generate(x_test, y_test) + """ + # inputs check + images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + + # generate one adversarial each time + adv_list = [] + success_list = [] + query_times_list = [] + for i in range(images.shape[0]): + is_success = False + q_times = 0 + x_ori = images[i] + if not self._bounds: + self._bounds = [np.min(x_ori), np.max(x_ori)] + pixel_deep = self._bounds[1] - self._bounds[0] + + q_times += 1 + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[i], axis=0),) + gt_boxes_i, gt_labels_i = np.expand_dims(gt_boxes[i], axis=0), np.expand_dims(gt_labels[i], axis=0) + inputs_i = (images[i],) + auxiliary_input_i + confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) + LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + + # step1, initializing + # initial global optimum fitness value, cannot set to be -inf + best_fitness = -np.inf + # initial global optimum position + best_position = x_ori + x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) + cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep, + (0 - self._per_bounds)*(np.abs(x_copies) + 0.1), + self._per_bounds*(np.abs(x_copies) + 0.1)) + # initial advs + par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) + # initial optimum positions for particles + par_best_poi = np.copy(par) + # initial optimum fitness values + par_best_fit = -np.inf*np.ones(self._pop_size) + # step2, optimization + # initial velocities for particles + v_particles = np.zeros(par.shape) + is_mutation = False + iters = 0 + while iters < self._t_max: + last_best_fit = best_fitness + ran_1 = np.random.random(par.shape) + ran_2 = np.random.random(par.shape) + v_particles = self._step_size*( + v_particles + self._c1*ran_1*(best_position - par)) \ + + self._c2*ran_2*(par_best_poi - par) + + par += v_particles + + if iters > 6 and is_mutation: + par = self._mutation_op(par) + + par = np.clip(np.clip(par, + x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds, + x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), + self._bounds[0], self._bounds[1]) + + confi_adv, _ = self._detection_scores( + (par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) + q_times += self._pop_size + fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) + par_best_fit, par_best_poi, best_fitness, best_position = self._update_best_fit_position(fit_value, + par_best_fit, + par_best_poi, + par, + best_fitness, + best_position) + iters += 1 + self._check_best_fitness(best_fitness, iters) + + is_mutation = False + if (best_fitness - last_best_fit) < last_best_fit*0.05: + is_mutation = True + + q_times += 1 + + _, correct_nums_adv = self._detection_scores( + (best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) + LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', + correct_nums_adv[0]) + if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): + is_success = True + + if is_success: + 
LOGGER.debug(TAG, 'successfully find one adversarial ' + 'sample and start Reduction process') break - if self._model_type == 'detection': - best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, - auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) + best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) if not is_success: LOGGER.debug(TAG, 'fail to find adversarial sample, iteration ' @@ -341,3 +501,43 @@ class PSOAttack(Attack): return np.asarray(success_list), \ np.asarray(adv_list), \ np.asarray(query_times_list) + + def generate(self, inputs, labels): + """ + Generate adversarial examples based on input data and + targeted labels (or ground_truth labels). + + Args: + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if + model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if + model_type='detection'. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should + be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels) + if model_type='detection'. + + Returns: + - numpy.ndarray, bool values for each attack result. + + - numpy.ndarray, generated adversarial examples. + + - numpy.ndarray, query times for each sample. + + Examples: + >>> net = Net() + >>> model = ModelToBeAttacked(net) + >>> attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) + >>> batch_size = 6 + >>> x_test = np.random.rand(batch_size, 10) + >>> y_test = np.random.randint(low=0, high=10, size=batch_size) + >>> y_test = np.eye(10)[y_test] + >>> y_test = y_test.astype(np.float32) + >>> _, adv_data, _ = attack.generate(x_test, y_test) + """ + # inputs check + if self._model_type == 'classification': + success_list, adv_data, query_time_list = self._generate_classification(inputs, labels) + + elif self._model_type == 'detection': + success_list, adv_data, query_time_list = self._generate_detection(inputs, labels) + + return success_list, adv_data, query_time_list diff --git a/mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py b/mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py index 054add8..786ab69 100644 --- a/mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py +++ b/mindarmour/adv_robustness/attacks/black/salt_and_pepper_attack.py @@ -40,6 +40,22 @@ class SaltAndPepperNoiseAttack(Attack): Default: True. Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> from mindarmour import BlackModel + >>> from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack + >>> from tests.ut.python.utils.mock_net import Net + >>> + >>> class ModelToBeAttacked(BlackModel): + >>> def __init__(self, network): + >>> super(ModelToBeAttacked, self).__init__() + >>> self._network = network + >>> def predict(self, inputs): + >>> result = self._network(Tensor(inputs.astype(np.float32))) + >>> return result.asnumpy() + >>> + >>> net = Net() + >>> model = ModelToBeAttacked(net) >>> attack = SaltAndPepperNoiseAttack(model) """ @@ -69,7 +85,12 @@ class SaltAndPepperNoiseAttack(Attack): - numpy.ndarray, query times for each sample. 
        Examples:
-            >>> adv_list = attack.generate(([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2])
+            >>> net = Net()
+            >>> model = ModelToBeAttacked(net)
+            >>> attack = SaltAndPepperNoiseAttack(model)
+            >>> x_test = np.asarray(np.random.random((1,1,32,32)), np.float32)
+            >>> y_test = np.random.randint(0, 3, size=1)
+            >>> _, adv_list, _ = attack.generate(x_test, y_test)
        """
        arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
        if not self._sparse:
diff --git a/mindarmour/adv_robustness/attacks/carlini_wagner.py b/mindarmour/adv_robustness/attacks/carlini_wagner.py
index 09c4046..a65532d 100644
--- a/mindarmour/adv_robustness/attacks/carlini_wagner.py
+++ b/mindarmour/adv_robustness/attacks/carlini_wagner.py
@@ -95,7 +95,24 @@ class CarliniWagnerL2Attack(Attack):
             input labels are onehot-coded. Default: True.
 
     Examples:
-        >>> attack = CarliniWagnerL2Attack(network)
+        >>> import numpy as np
+        >>> import mindspore.ops.operations as M
+        >>> from mindspore.nn import Cell
+        >>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._softmax = M.Softmax()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._softmax(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
+        >>> label_np = np.array([3]).astype(np.int64)
+        >>> num_classes = input_np.shape[1]
+        >>> attack = CarliniWagnerL2Attack(net, num_classes, targeted=False)
     """
 
     def __init__(self, network, num_classes, box_min=0.0, box_max=1.0,
@@ -246,6 +263,14 @@ class CarliniWagnerL2Attack(Attack):
             the_grad = the_grad*diff
         return inputs, the_grad
 
+    def _check_success(self, logits, labels):
+        """Check whether the attack succeeded for each example."""
+        if self._targeted:
+            is_adv = (np.argmax(logits, axis=1) == labels)
+        else:
+            is_adv = (np.argmax(logits, axis=1) != labels)
+        return is_adv
+
     def generate(self, inputs, labels):
         """
         Generate adversarial examples based on input data and targeted labels.
@@ -259,7 +284,30 @@ class CarliniWagnerL2Attack(Attack):
             numpy.ndarray, generated adversarial examples.
        Examples:
-            >>> advs = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [1, 2]]
+            >>> import numpy as np
+            >>> import mindspore.ops.operations as M
+            >>> from mindspore.nn import Cell
+            >>> from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack
+            >>>
+            >>> class Net(Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self._softmax = M.Softmax()
+            >>>
+            >>>     def construct(self, inputs):
+            >>>         out = self._softmax(inputs)
+            >>>         return out
+            >>>
+            >>> net = Net()
+            >>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
+            >>> num_classes = input_np.shape[1]
+            >>>
+            >>> label_np = np.array([3]).astype(np.int64)
+            >>> attack_nonTargeted = CarliniWagnerL2Attack(net, num_classes, targeted=False)
+            >>> advs_nonTargeted = attack_nonTargeted.generate(input_np, label_np)
+            >>>
+            >>> target_np = np.array([1]).astype(np.int64)
+            >>> attack_targeted = CarliniWagnerL2Attack(net, num_classes, targeted=True)
+            >>> advs_targeted = attack_targeted.generate(input_np, target_np)
        """
 
        LOGGER.debug(TAG, "enter the func generate.")
 
@@ -302,11 +350,7 @@ class CarliniWagnerL2Attack(Attack):
                 logits, x_input, reconstructed_original,
                 labels, const, self._confidence)
 
-            # check if attack success (include all examples)
-            if self._targeted:
-                is_adv = (np.argmax(logits, axis=1) == labels)
-            else:
-                is_adv = (np.argmax(logits, axis=1) != labels)
+            is_adv = self._check_success(logits, labels)
 
             for i in range(samples_num):
                 if is_adv[i]:
diff --git a/mindarmour/adv_robustness/attacks/deep_fool.py b/mindarmour/adv_robustness/attacks/deep_fool.py
index a303b17..2d20327 100644
--- a/mindarmour/adv_robustness/attacks/deep_fool.py
+++ b/mindarmour/adv_robustness/attacks/deep_fool.py
@@ -117,7 +117,23 @@ class DeepFool(Attack):
             input labels are onehot-coded. Default: True.
 
     Examples:
-        >>> attack = DeepFool(network)
+        >>> import numpy as np
+        >>> import mindspore.ops.operations as P
+        >>> from mindspore.nn import Cell
+        >>> from mindspore import Tensor
+        >>> from mindarmour.adv_robustness.attacks import DeepFool
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._softmax = P.Softmax()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._softmax(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> classes = 5
+        >>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, bounds=(0.0, 1.0))
     """
 
     def __init__(self, network, num_classes, model_type='classification',
@@ -165,14 +181,30 @@ class DeepFool(Attack):
             NotImplementedError: If norm_level is not in [2, np.inf, '2', 'inf'].
        Examples:
-            >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
+            >>> input_shape = (1, 5)
+            >>> _, classes = input_shape
+            >>> input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32)
+            >>> input_me = Tensor(input_np)
+            >>> true_labels = np.argmax(net(input_me).asnumpy(), axis=1)
+            >>> attack = DeepFool(net, classes, max_iters=10, norm_level=2, bounds=(0.0, 1.0))
+            >>> advs = attack.generate(input_np, true_labels)
        """
+
        if self._model_type == 'detection':
            return self._generate_detection(inputs, labels)
        if self._model_type == 'classification':
            return self._generate_classification(inputs, labels)
        return None
 
+    def _update_image(self, x_origin, r_tot):
+        """Update the image based on the bounds."""
+        if self._bounds is not None:
+            clip_min, clip_max = self._bounds
+            images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
+            images = np.clip(images, clip_min, clip_max)
+        else:
+            images = x_origin + (1 + self._overshoot)*r_tot
+        return images
 
     def _generate_detection(self, inputs, labels):
         """Generate adversarial examples in detection scenario"""
@@ -239,19 +271,12 @@ class DeepFool(Attack):
                 raise NotImplementedError(msg)
             r_tot[idx, ...] = r_tot[idx, ...] + r_i
 
-            if self._bounds is not None:
-                clip_min, clip_max = self._bounds
-                images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
-                images = np.clip(images, clip_min, clip_max)
-            else:
-                images = x_origin + (1 + self._overshoot)*r_tot
+            images = self._update_image(x_origin, r_tot)
             iteration += 1
             images = images.astype(images_dtype)
             del preds_logits, grads
         return images
-
-
     def _generate_classification(self, inputs, labels):
         """Generate adversarial examples in classification scenario"""
         inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                 'labels', labels)
diff --git a/mindarmour/adv_robustness/attacks/gradient_method.py b/mindarmour/adv_robustness/attacks/gradient_method.py
index e230a72..c0c14dd 100644
--- a/mindarmour/adv_robustness/attacks/gradient_method.py
+++ b/mindarmour/adv_robustness/attacks/gradient_method.py
@@ -47,9 +47,25 @@ class GradientMethod(Attack):
             is already equipped with loss function. Default: None.
 
     Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+        >>> from mindspore import Tensor
+        >>> from mindarmour.adv_robustness.attacks import FastGradientMethod
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
         >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
         >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
-        >>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
+        >>> net = Net()
+        >>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
         >>> adv_x = attack.generate(inputs, labels)
     """
 
@@ -155,9 +171,24 @@ class FastGradientMethod(GradientMethod):
             is already equipped with loss function. Default: None.
Examples: + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits + >>> from mindarmour.adv_robustness.attacks import FastGradientMethod + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> + >>> def construct(self, inputs): + >>> out = self._relu(inputs) + >>> return out + >>> >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) - >>> attack = FastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) + >>> net = Net() + >>> attack = FastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) >>> adv_x = attack.generate(inputs, labels) """ @@ -223,9 +254,24 @@ class RandomFastGradientMethod(FastGradientMethod): ValueError: eps is smaller than alpha! Examples: + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits + >>> from mindarmour.adv_robustness.attacks import RandomFastGradientMethod + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> + >>> def construct(self, inputs): + >>> out = self._relu(inputs) + >>> return out + >>> + >>> net = Net() >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) - >>> attack = RandomFastGradientMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) + >>> attack = RandomFastGradientMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) >>> adv_x = attack.generate(inputs, labels) """ @@ -265,9 +311,24 @@ class FastGradientSignMethod(GradientMethod): is already equipped with loss function. Default: None. Examples: + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits + >>> from mindarmour.adv_robustness.attacks import FastGradientSignMethod + >>> + >>> class Net(Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self._relu = nn.ReLU() + >>> + >>> def construct(self, inputs): + >>> out = self._relu(inputs) + >>> return out + >>> + >>> net = Net() >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]]) >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]) - >>> attack = FastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) + >>> attack = FastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) >>> adv_x = attack.generate(inputs, labels) """ @@ -329,9 +390,24 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): ValueError: eps is smaller than alpha! 
     Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+        >>> from mindarmour.adv_robustness.attacks import RandomFastGradientSignMethod
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
         >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
         >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
-        >>> attack = RandomFastGradientSignMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
+        >>> attack = RandomFastGradientSignMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
         >>> adv_x = attack.generate(inputs, labels)
     """
 
@@ -366,6 +442,21 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
             is already equipped with loss function. Default: None.
 
     Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+        >>> from mindarmour.adv_robustness.attacks import LeastLikelyClassMethod
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
         >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
         >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
-        >>> attack = LeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
+        >>> attack = LeastLikelyClassMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
@@ -404,6 +495,21 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
             ValueError: eps is smaller than alpha!
 
     Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+        >>> from mindarmour.adv_robustness.attacks import RandomLeastLikelyClassMethod
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
         >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
         >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
-        >>> attack = RandomLeastLikelyClassMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
+        >>> attack = RandomLeastLikelyClassMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
index 2577eac..0c13135 100644
--- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
+++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
@@ -184,7 +184,22 @@ class BasicIterativeMethod(IterativeGradientMethod):
             is already equipped with loss function. Default: None.
     Examples:
-        >>> attack = BasicIterativeMethod(network, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+        >>> from mindarmour.adv_robustness.attacks import BasicIterativeMethod
+        >>>
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
     """
     def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
                  is_targeted=False, nb_iter=5, loss_fn=None):
@@ -215,6 +230,17 @@ class BasicIterativeMethod(IterativeGradientMethod):
             numpy.ndarray, generated adversarial examples.
 
         Examples:
+            >>> class Net(Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self._relu = nn.ReLU()
+            >>>
+            >>>     def construct(self, inputs):
+            >>>         out = self._relu(inputs)
+            >>>         return out
+            >>>
+            >>> net = Net()
+            >>> attack = BasicIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
             >>> adv_x = attack.generate([[0.3, 0.2, 0.6],
             >>>                          [0.3, 0.2, 0.4]],
             >>>                         [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
@@ -303,6 +329,22 @@ class MomentumIterativeMethod(IterativeGradientMethod):
             numpy.ndarray, generated adversarial examples.
 
         Examples:
+            >>> import numpy as np
+            >>> import mindspore.nn as nn
+            >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+            >>> from mindarmour.adv_robustness.attacks import MomentumIterativeMethod
+            >>>
+            >>> class Net(Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self._relu = nn.ReLU()
+            >>>
+            >>>     def construct(self, inputs):
+            >>>         out = self._relu(inputs)
+            >>>         return out
+            >>>
+            >>> net = Net()
+            >>> attack = MomentumIterativeMethod(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
             >>> adv_x = attack.generate([[0.5, 0.2, 0.6],
             >>>                          [0.3, 0, 0.2]],
             >>>                         [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
@@ -433,6 +475,22 @@ class ProjectedGradientDescent(BasicIterativeMethod):
             numpy.ndarray, generated adversarial examples.
 
         Examples:
+            >>> import numpy as np
+            >>> import mindspore.nn as nn
+            >>> from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
+            >>> from mindarmour.adv_robustness.attacks import ProjectedGradientDescent
+            >>>
+            >>> class Net(Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self._relu = nn.ReLU()
+            >>>
+            >>>     def construct(self, inputs):
+            >>>         out = self._relu(inputs)
+            >>>         return out
+            >>>
+            >>> net = Net()
+            >>> attack = ProjectedGradientDescent(net, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False))
             >>> adv_x = attack.generate([[0.6, 0.2, 0.6],
             >>>                          [0.3, 0.3, 0.4]],
             >>>                         [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
diff --git a/mindarmour/adv_robustness/attacks/jsma.py b/mindarmour/adv_robustness/attacks/jsma.py
index ccd9e7d..560a5a6 100644
--- a/mindarmour/adv_robustness/attacks/jsma.py
+++ b/mindarmour/adv_robustness/attacks/jsma.py
@@ -54,7 +54,23 @@ class JSMAAttack(Attack):
             input labels are onehot-coded. Default: True.
     Examples:
-        >>> attack = JSMAAttack(network)
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore.nn import Cell
+        >>> from mindarmour.adv_robustness.attacks import JSMAAttack
+        >>> class Net(Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self._relu = nn.ReLU()
+        >>>
+        >>>     def construct(self, inputs):
+        >>>         out = self._relu(inputs)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> input_shape = (1, 5)
+        >>> batch_size, classes = input_shape
+        >>> attack = JSMAAttack(net, classes, max_iteration=5)
     """
 
     def __init__(self, network, num_classes, box_min=0.0, box_max=1.0,
@@ -181,7 +197,13 @@ class JSMAAttack(Attack):
             numpy.ndarray, adversarial samples.
 
         Examples:
-            >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
+            >>> input_shape = (1, 5)
+            >>> batch_size, classes = input_shape
+            >>> input_np = np.random.random(input_shape).astype(np.float32)
+            >>> label_np = np.random.randint(classes, size=batch_size)
+            >>>
+            >>> attack = JSMAAttack(net, classes, max_iteration=5)
+            >>> advs = attack.generate(input_np, label_np)
         """
         inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                 'labels', labels)
diff --git a/mindarmour/adv_robustness/attacks/lbfgs.py b/mindarmour/adv_robustness/attacks/lbfgs.py
index 86f4f33..dd7912c 100644
--- a/mindarmour/adv_robustness/attacks/lbfgs.py
+++ b/mindarmour/adv_robustness/attacks/lbfgs.py
@@ -54,7 +54,12 @@ class LBFGS(Attack):
             input labels are onehot-coded. Default: False.
 
     Examples:
-        >>> attack = LBFGS(network)
+        >>> import numpy as np
+        >>> from mindarmour.adv_robustness.attacks import LBFGS
+        >>> from tests.ut.python.utils.mock_net import Net
+        >>>
+        >>> net = Net()
+        >>> attack = LBFGS(net, is_targeted=True)
     """
     def __init__(self, network, eps=1e-5, bounds=(0.0, 1.0), is_targeted=True,
                  nb_iter=150, search_iters=30, loss_fn=None, sparse=False):
@@ -94,6 +99,7 @@ class LBFGS(Attack):
             numpy.ndarray, generated adversarial examples.
 
         Examples:
+            >>> attack = LBFGS(net, is_targeted=True)
             >>> adv = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]], [2, 2])
         """
         LOGGER.debug(TAG, 'start to generate adv image.')
@@ -191,7 +197,7 @@ class LBFGS(Attack):
 
     def _optimize(self, start_input, labels, epsilon):
         """
-        Given loss fuction and gradient, use l_bfgs_b algorithm to update input
+        Given loss function and gradient, use l_bfgs_b algorithm to update input
         sample. The epsilon will be doubled until an adversarial example is
         found.
 
         Args:
diff --git a/tests/ut/python/adv_robustness/attacks/black/test_nes.py b/tests/ut/python/adv_robustness/attacks/black/test_nes.py
index f6e9c76..4530b9b 100644
--- a/tests/ut/python/adv_robustness/attacks/black/test_nes.py
+++ b/tests/ut/python/adv_robustness/attacks/black/test_nes.py
@@ -28,7 +28,7 @@ from tests.ut.python.utils.mock_net import Net
 
 context.set_context(mode=context.GRAPH_MODE)
 LOGGER = LogUtil.get_instance()
-TAG = 'HopSkipJumpAttack'
+TAG = 'NaturalEvolutionaryStrategy'
 
 
 class ModelToBeAttacked(BlackModel):
@@ -100,7 +100,7 @@ def get_dataset(current_dir):
 
 def nes_mnist_attack(scene, top_k):
     """
-    hsja-Attack test
+    nes-Attack test
     """
     current_dir = os.path.dirname(os.path.abspath(__file__))
     test_images, test_labels = get_dataset(current_dir)
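
Reviewer note, not part of the patch: the snippet below is a minimal, self-contained sketch of how the refactored classification path can be exercised end to end after this change. It reuses the Softmax-based Net and the ModelToBeAttacked wrapper from the GeneticAttack docstring above; the batch size, feature count, and attack settings are illustrative assumptions only, not values taken from the patch.

import numpy as np
import mindspore.ops.operations as M
from mindspore import Tensor
from mindspore.nn import Cell
from mindarmour import BlackModel
from mindarmour.adv_robustness.attacks import GeneticAttack


class Net(Cell):
    """Toy network that only applies Softmax, as in the docstring examples."""
    def __init__(self):
        super(Net, self).__init__()
        self._softmax = M.Softmax()

    def construct(self, inputs):
        return self._softmax(inputs)


class ModelToBeAttacked(BlackModel):
    """Black-box wrapper exposing only predict(), as the black attacks require."""
    def __init__(self, network):
        super(ModelToBeAttacked, self).__init__()
        self._network = network

    def predict(self, inputs):
        return self._network(Tensor(inputs.astype(np.float32))).asnumpy()


model = ModelToBeAttacked(Net())
# model_type defaults to 'classification', so generate() dispatches to
# _generate_classification(); one-hot labels require sparse=False.
attack = GeneticAttack(model, targeted=False, sparse=False)

batch_size, num_classes = 6, 10  # illustrative sizes only
x_test = np.random.rand(batch_size, num_classes).astype(np.float32)
y_test = np.eye(num_classes)[np.random.randint(0, num_classes, size=batch_size)].astype(np.float32)

success, adv_data, query_times = attack.generate(x_test, y_test)
print(success, query_times)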