diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py index bce2974..4b947bc 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py @@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel): def predict(self, inputs): """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(inputs.shape) == 3: + inputs = np.expand_dims(inputs, axis=0) result = self._network(Tensor(inputs.astype(np.float32))) return result.asnumpy() diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_pso.py b/examples/model_security/model_attacks/black_box/mnist_attack_pso.py index f98bfb2..d7b199e 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_pso.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_pso.py @@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel): def predict(self, inputs): """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(inputs.shape) == 3: + inputs = np.expand_dims(inputs, axis=0) result = self._network(Tensor(inputs.astype(np.float32))) return result.asnumpy() diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_genetic.py b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_genetic.py new file mode 100644 index 0000000..05c557b --- /dev/null +++ b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_genetic.py @@ -0,0 +1,150 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PSO attack for Faster R-CNN""" +import os +import numpy as np + +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed +from mindspore import Tensor + +from mindarmour import BlackModel +from mindarmour.adv_robustness.attacks.black.genetic_attack import GeneticAttack +from mindarmour.utils.logger import LogUtil + +from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset + +# pylint: disable=locally-disabled, unused-argument, redefined-outer-name +LOGGER = LogUtil.get_instance() +LOGGER.set_level('INFO') + +set_seed(1) + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1) + + +class ModelToBeAttacked(BlackModel): + """model to be attack""" + + def __init__(self, network): + super(ModelToBeAttacked, self).__init__() + self._network = network + + def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num): + """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(images.shape) == 3: + inputs_num = 1 + images = np.expand_dims(images, axis=0) + else: + inputs_num = images.shape[0] + box_and_confi = [] + pred_labels = [] + gt_number = np.sum(gt_num) + + for i in range(inputs_num): + inputs_i = np.expand_dims(images[i], axis=0) + output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas), + Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num)) + all_bbox = output[0] + all_labels = output[1] + all_mask = output[2] + all_bbox_squee = np.squeeze(all_bbox.asnumpy()) + all_labels_squee = np.squeeze(all_labels.asnumpy()) + all_mask_squee = np.squeeze(all_mask.asnumpy()) + all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] + all_labels_tmp_mask = all_labels_squee[all_mask_squee] + + if all_bboxes_tmp_mask.shape[0] > gt_number + 1: + inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) + inds = inds[:gt_number+1] + all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] + all_labels_tmp_mask = all_labels_tmp_mask[inds] + box_and_confi.append(all_bboxes_tmp_mask) + pred_labels.append(all_labels_tmp_mask) + return np.array(box_and_confi), np.array(pred_labels) + + +if __name__ == '__main__': + prefix = 'FasterRcnn_eval.mindrecord' + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix) + pre_trained = '/ckpt_path' + print("CHECKING MINDRECORD FILES ...") + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if os.path.isdir(config.coco_root): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + + print('Start generate adversarial samples.') + + # build network and dataset + ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \ + repeat_num=1, is_training=False) + net = Faster_Rcnn_Resnet50(config) + param_dict = load_checkpoint(pre_trained) + load_param_into_net(net, param_dict) + net = net.set_train(False) + + # build attacker + model = ModelToBeAttacked(net) + attack = GeneticAttack(model, model_type='detection', max_steps=50, reserve_ratio=0.3, mutation_rate=0.05, + per_bounds=0.5, step_size=0.25, temp=0.1) + + # generate adversarial samples + sample_num = 5 + ori_imagess = [] + adv_imgs = [] + ori_meta = [] + ori_box = [] + ori_labels = [] + ori_gt_num = [] + idx = 0 + for data in ds.create_dict_iterator(): + if idx > sample_num: + break + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + + ori_imagess.append(img_data.asnumpy()) + ori_meta.append(img_metas.asnumpy()) + ori_box.append(gt_bboxes.asnumpy()) + ori_labels.append(gt_labels.asnumpy()) + ori_gt_num.append(gt_num.asnumpy()) + + all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(), + gt_labels.asnumpy(), gt_num.asnumpy()) + + pre_gt_boxes, pre_gt_label = model.predict(*all_inputs) + success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label)) + adv_imgs.append(adv_img) + idx += 1 + np.save('ori_imagess.npy', ori_imagess) + np.save('ori_meta.npy', ori_meta) + np.save('ori_box.npy', ori_box) + np.save('ori_labels.npy', ori_labels) + np.save('ori_gt_num.npy', ori_gt_num) + np.save('adv_imgs.npy', adv_imgs) + print('Generate adversarial samples complete.') diff --git a/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pso.py b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pso.py new file mode 100644 index 0000000..60546f7 --- /dev/null +++ b/examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pso.py @@ -0,0 +1,149 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PSO attack for Faster R-CNN""" +import os +import numpy as np + +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed +from mindspore import Tensor + +from mindarmour import BlackModel +from mindarmour.adv_robustness.attacks.black.pso_attack import PSOAttack +from mindarmour.utils.logger import LogUtil + +from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 +from src.config import config +from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset + +# pylint: disable=locally-disabled, unused-argument, redefined-outer-name +LOGGER = LogUtil.get_instance() +LOGGER.set_level('INFO') + +set_seed(1) + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1) + + +class ModelToBeAttacked(BlackModel): + """model to be attack""" + + def __init__(self, network): + super(ModelToBeAttacked, self).__init__() + self._network = network + + def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num): + """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(images.shape) == 3: + inputs_num = 1 + images = np.expand_dims(images, axis=0) + else: + inputs_num = images.shape[0] + box_and_confi = [] + pred_labels = [] + gt_number = np.sum(gt_num) + + for i in range(inputs_num): + inputs_i = np.expand_dims(images[i], axis=0) + output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas), + Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num)) + all_bbox = output[0] + all_labels = output[1] + all_mask = output[2] + all_bbox_squee = np.squeeze(all_bbox.asnumpy()) + all_labels_squee = np.squeeze(all_labels.asnumpy()) + all_mask_squee = np.squeeze(all_mask.asnumpy()) + all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] + all_labels_tmp_mask = all_labels_squee[all_mask_squee] + + if all_bboxes_tmp_mask.shape[0] > gt_number + 1: + inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) + inds = inds[:gt_number+1] + all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] + all_labels_tmp_mask = all_labels_tmp_mask[inds] + box_and_confi.append(all_bboxes_tmp_mask) + pred_labels.append(all_labels_tmp_mask) + return np.array(box_and_confi), np.array(pred_labels) + + +if __name__ == '__main__': + prefix = 'FasterRcnn_eval.mindrecord' + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix) + pre_trained = '/ckpt_path' + print("CHECKING MINDRECORD FILES ...") + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if os.path.isdir(config.coco_root): + print("Create Mindrecord. It may take some time.") + data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + + print('Start generate adversarial samples.') + + # build network and dataset + ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \ + repeat_num=1, is_training=False) + net = Faster_Rcnn_Resnet50(config) + param_dict = load_checkpoint(pre_trained) + load_param_into_net(net, param_dict) + net = net.set_train(False) + + # build attacker + model = ModelToBeAttacked(net) + attack = PSOAttack(model, c=0.2, t_max=50, pm=0.5, model_type='detection', reserve_ratio=0.3) + + # generate adversarial samples + sample_num = 5 + ori_imagess = [] + adv_imgs = [] + ori_meta = [] + ori_box = [] + ori_labels = [] + ori_gt_num = [] + idx = 0 + for data in ds.create_dict_iterator(): + if idx > sample_num: + break + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + + ori_imagess.append(img_data.asnumpy()) + ori_meta.append(img_metas.asnumpy()) + ori_box.append(gt_bboxes.asnumpy()) + ori_labels.append(gt_labels.asnumpy()) + ori_gt_num.append(gt_num.asnumpy()) + + all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(), + gt_labels.asnumpy(), gt_num.asnumpy()) + + pre_gt_boxes, pre_gt_label = model.predict(*all_inputs) + success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label)) + adv_imgs.append(adv_img) + idx += 1 + np.save('ori_imagess.npy', ori_imagess) + np.save('ori_meta.npy', ori_meta) + np.save('ori_box.npy', ori_box) + np.save('ori_labels.npy', ori_labels) + np.save('ori_gt_num.npy', ori_gt_num) + np.save('adv_imgs.npy', adv_imgs) + print('Generate adversarial samples complete.') diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index ca17956..f97b662 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -19,8 +19,10 @@ from abc import abstractmethod import numpy as np from mindarmour.utils._check_param import check_pair_numpy_param, \ - check_int_positive + check_int_positive, check_equal_shape, check_numpy_param, check_model +from mindarmour.utils.util import calculate_iou from mindarmour.utils.logger import LogUtil +from mindarmour.adv_robustness.attacks.black.black_model import BlackModel LOGGER = LogUtil.get_instance() TAG = 'Attack' @@ -87,7 +89,6 @@ class Attack: # Black-attack methods will return 3 values, just get the second. res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x) - adv_x = np.concatenate(res, axis=0) return adv_x @@ -109,3 +110,125 @@ class Attack: '`Attack` and should be implemented in child class.' LOGGER.error(TAG, msg) raise NotImplementedError(msg) + + @staticmethod + def _reduction(x_ori, q_times, label, best_position, model, targeted_attack): + """ + Decrease the differences between the original samples and adversarial samples. + + Args: + x_ori (numpy.ndarray): Original samples. + q_times (int): Query times. + label (int): Target label ot ground-truth label. + best_position (numpy.ndarray): Adversarial examples. + model (BlackModel): Target model. + targeted_attack (bool): If True, it means this is a targeted attack. If False, + it means this is an untargeted attack. + + Returns: + numpy.ndarray, adversarial examples after reduction. + + Examples: + >>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1, + >>> [0.12, 0.15, 0.25]) + """ + LOGGER.info(TAG, 'Reduction begins...') + model = check_model('model', model, BlackModel) + x_ori = check_numpy_param('x_ori', x_ori) + best_position = check_numpy_param('best_position', best_position) + x_ori, best_position = check_equal_shape('x_ori', x_ori, + 'best_position', best_position) + x_ori_fla = x_ori.flatten() + best_position_fla = best_position.flatten() + pixel_deep = np.max(x_ori) - np.min(x_ori) + nums_pixel = len(x_ori_fla) + for i in range(nums_pixel): + diff = x_ori_fla[i] - best_position_fla[i] + if abs(diff) > pixel_deep*0.1: + best_position_fla[i] += diff*0.5 + cur_label = np.argmax( + model.predict(best_position_fla.reshape(x_ori.shape))) + q_times += 1 + if targeted_attack: + if cur_label != label: + best_position_fla[i] -= diff * 0.5 + + else: + if cur_label == label: + best_position_fla -= diff*0.5 + return best_position_fla.reshape(x_ori.shape), q_times + + def _fast_reduction(self, x_ori, best_position, q_times, auxiliary_inputs, gt_boxes, gt_labels, model): + """ + Decrease the differences between the original samples and adversarial samples in a fast way. + + Args: + x_ori (numpy.ndarray): Original samples. + best_position (numpy.ndarray): Adversarial examples. + q_times (int): Query times. + auxiliary_inputs (tuple): Auxiliary inputs mathced with x_ori. + gt_boxes (numpy.ndarray): Ground-truth boxes of x_ori. + gt_labels (numpy.ndarray): Ground-truth labels of x_ori. + model (BlackModel): Target model. + + Returns: + - numpy.ndarray, adversarial examples after reduction. + + - int, total query times after reduction. + """ + LOGGER.info(TAG, 'Reduction begins...') + model = check_model('model', model, BlackModel) + x_ori = check_numpy_param('x_ori', x_ori) + best_position = check_numpy_param('best_position', best_position) + x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) + x_shape = best_position.shape + reduction_iters = 10000 # recover 0.01% each step + _, original_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) + for _ in range(reduction_iters): + diff = x_ori - best_position + res = 0.5*diff*(np.random.random(x_shape) < 0.0001) + best_position += res + _, correct_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) + q_times += 1 + if correct_num > original_num: + best_position -= res + return best_position, q_times + + @staticmethod + def _detection_scores(inputs, gt_boxes, gt_labels, model): + """ + Evaluate the detection result of inputs, specially for object detection models. + + Args: + inputs (numpy.ndarray): Input samples. + gt_boxes (numpy.ndarray): Ground-truth boxes of inputs. + gt_labels (numpy.ndarray): Ground-truth labels of inputs. + model (BlackModel): Target model. + + Returns: + - numpy.ndarray, detection scores of inputs. + + - numpy.ndarray, the number of objects that are correctly detected. + """ + model = check_model('model', model, BlackModel) + box_and_confi, pred_labels = model.predict(*inputs) + det_scores = [] + correct_labels_num = [] + gt_boxes_num = gt_boxes.shape[0] + iou_thres = 0.5 + for boxes, labels in zip(box_and_confi, pred_labels): + score = 0 + box_num = boxes.shape[0] + correct_label_flag = np.zeros(gt_labels.shape) + for i in range(box_num): + pred_box = boxes[i] + max_iou_confi = 0 + for j in range(gt_boxes_num): + iou = calculate_iou(pred_box[:4], gt_boxes[j][:4]) + if labels[i] == gt_labels[j] and iou > iou_thres: + max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) + correct_label_flag[j] = 1 + score += max_iou_confi + det_scores.append(score) + correct_labels_num.append(np.sum(correct_label_flag)) + return np.array(det_scores), np.array(correct_labels_num) diff --git a/mindarmour/adv_robustness/attacks/black/genetic_attack.py b/mindarmour/adv_robustness/attacks/black/genetic_attack.py index 6d80aa2..9aa8a8c 100644 --- a/mindarmour/adv_robustness/attacks/black/genetic_attack.py +++ b/mindarmour/adv_robustness/attacks/black/genetic_attack.py @@ -20,38 +20,14 @@ from scipy.special import softmax from mindarmour.utils.logger import LogUtil from mindarmour.utils._check_param import check_numpy_param, check_model, \ check_pair_numpy_param, check_param_type, check_value_positive, \ - check_int_positive, check_param_multi_types -from ..attack import Attack + check_int_positive, check_detection_inputs, check_value_non_negative, check_param_multi_types +from mindarmour.adv_robustness.attacks.attack import Attack from .black_model import BlackModel LOGGER = LogUtil.get_instance() TAG = 'GeneticAttack' -def _mutation(cur_pop, step_noise=0.01, prob=0.005): - """ - Generate mutation samples in genetic_attack. - - Args: - cur_pop (numpy.ndarray): Samples before mutation. - step_noise (float): Noise range. Default: 0.01. - prob (float): Mutation probability. Default: 0.005. - - Returns: - numpy.ndarray, samples after mutation operation in genetic_attack. - - Examples: - >>> mul_pop = self._mutation_op([0.2, 0.3, 0.4], step_noise=0.03, - >>> prob=0.01) - """ - cur_pop = check_numpy_param('cur_pop', cur_pop) - perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5, - -step_noise, step_noise) - mutated_pop = perturb_noise*( - np.random.random(cur_pop.shape) < prob) + cur_pop - return mutated_pop - - class GeneticAttack(Attack): """ The Genetic Attack represents the black-box attack based on the genetic algorithm, @@ -65,6 +41,13 @@ class GeneticAttack(Attack): Args: model (BlackModel): Target model. + model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. + default: 'classification'. + targeted (bool): If True, turns on the targeted attack. If False, + turns on untargeted attack. It should be noted that only untargeted attack + is supproted for model_type='detection', Default: False. + reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for + model_type='detection'. Default: 0.3. pop_size (int): The number of particles, which should be greater than zero. Default: 6. mutation_rate (float): The probability of mutations. Default: 0.005. @@ -73,23 +56,33 @@ class GeneticAttack(Attack): example. Default: 1000. step_size (float): Attack step size. Default: 0.2. temp (float): Sampling temperature for selection. Default: 0.3. - bounds (tuple): Upper and lower bounds of data. In form of (clip_min, - clip_max). Default: (0, 1.0) + The greater the temp, the greater the differences between individuals' + selecting probabilities. + bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form + of (clip_min, clip_max). Default: (0, 1.0). adaptive (bool): If True, turns on dynamic scaling of mutation parameters. If false, turns on static mutation parameters. Default: False. sparse (bool): If True, input labels are sparse-encoded. If False, input labels are one-hot-encoded. Default: True. + c (float): Weight of perturbation loss. Default: 0.1. Examples: >>> attack = GeneticAttack(model) """ - def __init__(self, model, pop_size=6, - mutation_rate=0.005, per_bounds=0.15, max_steps=1000, - step_size=0.20, temp=0.3, bounds=(0, 1.0), adaptive=False, - sparse=True): + def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True, + pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3, + bounds=(0, 1.0), adaptive=False, c=0.1): super(GeneticAttack, self).__init__() self._model = check_model('model', model, BlackModel) + self._model_type = check_param_type('model_type', model_type, str) + self._targeted = check_param_type('targeted', targeted, bool) + self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) + if self._reserve_ratio > 1: + msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio) + LOGGER.error(TAG, msg) + raise ValueError(msg) + self._sparse = check_param_type('sparse', sparse, bool) self._per_bounds = check_value_positive('per_bounds', per_bounds) self._pop_size = check_int_positive('pop_size', pop_size) self._step_size = check_value_positive('step_size', step_size) @@ -98,16 +91,41 @@ class GeneticAttack(Attack): self._mutation_rate = check_value_positive('mutation_rate', mutation_rate) self._adaptive = check_param_type('adaptive', adaptive, bool) - self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) - for b in self._bounds: - _ = check_param_multi_types('bound', b, [int, float]) # initial global optimum fitness value - self._best_fit = -1 + self._best_fit = -np.inf # count times of no progress self._plateau_times = 0 - # count times of changing attack step + # count times of changing attack step_size self._adap_times = 0 - self._sparse = check_param_type('sparse', sparse, bool) + self._bounds = bounds + if self._bounds is not None: + self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) + for b in self._bounds: + _ = check_param_multi_types('bound', b, [int, float]) + self._c = check_value_positive('c', c) + + def _mutation(self, cur_pop, step_noise=0.01, prob=0.005): + """ + Generate mutation samples in genetic_attack. + + Args: + cur_pop (numpy.ndarray): Samples before mutation. + step_noise (float): Noise range. Default: 0.01. + prob (float): Mutation probability. Default: 0.005. + + Returns: + numpy.ndarray, samples after mutation operation in genetic_attack. + + Examples: + >>> mul_pop = self._mutation_op([0.2, 0.3, 0.4], step_noise=0.03, + >>> prob=0.01) + """ + cur_pop = check_numpy_param('cur_pop', cur_pop) + perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5, + -step_noise, step_noise)*(self._bounds[1] - self._bounds[0]) + mutated_pop = perturb_noise*( + np.random.random(cur_pop.shape) < prob) + cur_pop + return mutated_pop def generate(self, inputs, labels): """ @@ -115,8 +133,10 @@ class GeneticAttack(Attack): labels (or ground_truth labels). Args: - inputs (numpy.ndarray): Input samples. - labels (numpy.ndarray): Targeted labels. + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) + or only one array if model_type='detection'. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels + should be (gt_boxes, gt_labels) if model_type='detection'. Returns: - numpy.ndarray, bool values for each attack result. @@ -130,47 +150,96 @@ class GeneticAttack(Attack): >>> [0.3, 0.3, 0.2]], >>> [1, 2]) """ - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - # if input is one-hot encoded, get sparse format value - if not self._sparse: - if labels.ndim != 2: - raise ValueError('labels must be 2 dims, ' - 'but got {} dims.'.format(labels.ndim)) - labels = np.argmax(labels, axis=1) + if self._model_type == 'classification': + inputs, labels = check_pair_numpy_param('inputs', inputs, + 'labels', labels) + if not self._sparse: + if labels.ndim != 2: + raise ValueError('labels must be 2 dims, ' + 'but got {} dims.'.format(labels.ndim)) + labels = np.argmax(labels, axis=1) + images = inputs + elif self._model_type == 'detection': + images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + adv_list = [] success_list = [] query_times_list = [] - for i in range(inputs.shape[0]): + for i in range(images.shape[0]): is_success = False - target_label = labels[i] - iters = 0 - x_ori = inputs[i] + x_ori = images[i] + if not self._bounds: + self._bounds = [np.min(x_ori), np.max(x_ori)] + pixel_deep = self._bounds[1] - self._bounds[0] + + if self._model_type == 'classification': + label_i = labels[i] + elif self._model_type == 'detection': + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[i], axis=0),) + gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i] + inputs_i = (images[i],) + auxiliary_input_i + confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) + LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + # generate particles - ori_copies = np.repeat( - x_ori[np.newaxis, :], self._pop_size, axis=0) + ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) # initial perturbations - cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size, - (0 - self._per_bounds), - self._per_bounds) + cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep, + (0 - self._per_bounds)*pixel_deep, + self._per_bounds*pixel_deep) + query_times = 0 + iters = 0 while iters < self._max_steps: iters += 1 cur_pop = np.clip( ori_copies + cur_pert, self._bounds[0], self._bounds[1]) - pop_preds = self._model.predict(cur_pop) - query_times += cur_pop.shape[0] - all_preds = np.argmax(pop_preds, axis=1) - success_pop = np.equal(target_label, all_preds).astype(np.int32) - success = max(success_pop) - if success == 1: - is_success = True - adv = cur_pop[np.argmax(success_pop)] + + if self._model_type == 'classification': + pop_preds = self._model.predict(cur_pop) + query_times += cur_pop.shape[0] + all_preds = np.argmax(pop_preds, axis=1) + if self._targeted: + success_pop = np.equal(label_i, all_preds).astype(np.int32) + else: + success_pop = np.not_equal(label_i, all_preds).astype(np.int32) + is_success = max(success_pop) + best_idx = np.argmax(success_pop) + target_preds = pop_preds[:, label_i] + others_preds_sum = np.sum(pop_preds, axis=1) - target_preds + if self._targeted: + fit_vals = target_preds - others_preds_sum + else: + fit_vals = others_preds_sum - target_preds + + elif self._model_type == 'detection': + confi_adv, correct_nums_adv = self._detection_scores( + (cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) + LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', + np.min(correct_nums_adv)) + query_times += self._pop_size + fit_vals = abs( + confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( + (cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) + if np.max(fit_vals) < 0: + self._c /= 2 + + if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): + is_success = True + best_idx = np.argmin(correct_nums_adv) + + if is_success: + LOGGER.debug(TAG, 'successfully find one adversarial sample ' + 'and start Reduction process.') + final_adv = cur_pop[best_idx] + if self._model_type == 'classification': + final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, + model=self._model, targeted_attack=self._targeted) break - target_preds = pop_preds[:, target_label] - others_preds_sum = np.sum(pop_preds, axis=1) - target_preds - fit_vals = target_preds - others_preds_sum - best_fit = max(target_preds - np.max(pop_preds)) + + best_fit = max(fit_vals) if best_fit > self._best_fit: self._best_fit = best_fit self._plateau_times = 0 @@ -206,17 +275,19 @@ class GeneticAttack(Attack): cross_probs = (np.random.random(parent1.shape) > parent2_probs).astype(np.int32) childs = parent1*cross_probs + parent2*(1 - cross_probs) - mutated_childs = _mutation( + mutated_childs = self._mutation( childs, step_noise=self._per_bounds*step_noise, prob=step_p) cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :])) - if is_success: - LOGGER.debug(TAG, 'successfully find one adversarial sample ' - 'and start Reduction process.') - adv_list.append(adv) - else: + + if not is_success: LOGGER.debug(TAG, 'fail to find adversarial sample.') - adv_list.append(elite + x_ori) + final_adv = elite + x_ori + if self._model_type == 'detection': + final_adv, query_times = self._fast_reduction( + x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) + adv_list.append(final_adv) + LOGGER.debug(TAG, 'iteration times is: %d and query times is: %d', iters, diff --git a/mindarmour/adv_robustness/attacks/black/pso_attack.py b/mindarmour/adv_robustness/attacks/black/pso_attack.py index 3627ab4..46f8d0f 100644 --- a/mindarmour/adv_robustness/attacks/black/pso_attack.py +++ b/mindarmour/adv_robustness/attacks/black/pso_attack.py @@ -19,7 +19,8 @@ import numpy as np from mindarmour.utils.logger import LogUtil from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \ check_numpy_param, check_value_positive, check_int_positive, \ - check_param_type, check_equal_shape, check_param_multi_types + check_param_type, check_param_multi_types,\ + check_value_non_negative, check_detection_inputs from ..attack import Attack from .black_model import BlackModel @@ -53,18 +54,21 @@ class PSOAttack(Attack): bounds (tuple): Upper and lower bounds of data. In form of (clip_min, clip_max). Default: None. targeted (bool): If True, turns on the targeted attack. If False, - turns on untargeted attack. Default: False. - reduction_iters (int): Cycle times in reduction process. Default: 3. + turns on untargeted attack. It should be noted that only untargeted attack + is supproted for model_type='detection', Default: False. sparse (bool): If True, input labels are sparse-encoded. If False, input labels are one-hot-encoded. Default: True. + model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. + default: 'classification'. + reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for + model_type='detection'. Default: 0.3. Examples: >>> attack = PSOAttack(model) """ - def __init__(self, model, step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0, - c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None, - targeted=False, reduction_iters=3, sparse=True): + def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True, + step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0, c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None): super(PSOAttack, self).__init__() self._model = check_model('model', model, BlackModel) self._step_size = check_value_positive('step_size', step_size) @@ -74,14 +78,24 @@ class PSOAttack(Attack): self._c = check_value_positive('c', c) self._pop_size = check_int_positive('pop_size', pop_size) self._pm = check_value_positive('pm', pm) - self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) - for b in self._bounds: - _ = check_param_multi_types('bound', b, [int, float]) + self._bounds = bounds + if self._bounds is not None: + self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) + for b in self._bounds: + _ = check_param_multi_types('bound', b, [int, float]) self._targeted = check_param_type('targeted', targeted, bool) self._t_max = check_int_positive('t_max', t_max) - self._reduce_iters = check_int_positive('reduction_iters', - reduction_iters) self._sparse = check_param_type('sparse', sparse, bool) + self._model_type = check_param_type('model_type', model_type, str) + if self._model_type not in ('classification', 'detection'): + msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) + LOGGER.error(TAG, msg) + raise ValueError(msg) + self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) + if self._reserve_ratio > 1: + msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio) + LOGGER.error(TAG, msg) + raise ValueError(msg) def _fitness(self, confi_ori, confi_adv, x_ori, x_adv): """ @@ -109,63 +123,50 @@ class PSOAttack(Attack): fit_value = abs( confi_ori - confi_adv) - self._c / self._pop_size*np.linalg.norm( (x_adv - x_ori).reshape(x_adv.shape[0], -1), axis=1) + if np.max(fit_value) < 0: + self._c /= 2 return fit_value - def _mutation_op(self, cur_pop): + def _confidence_cla(self, inputs, labels): """ - Generate mutation samples. + Calculate the prediction confidence of corresponding label or max confidence of inputs. + + Args: + inputs (numpy.ndarray): Input samples. + labels (Union[numpy.int, numpy.int16, numpy.int32, numpy.int64]): Target labels. + + Returns: + float, the prediction confidences of inputs. """ - cur_pop = check_numpy_param('cur_pop', cur_pop) - perturb_noise = np.random.random(cur_pop.shape) - 0.5 - mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) - < self._pm) + cur_pop - mutated_pop = np.clip(mutated_pop, cur_pop*(1 - self._per_bounds), - cur_pop*(1 + self._per_bounds)) - return mutated_pop + check_numpy_param('inputs', inputs) + check_param_multi_types('labels', labels, (np.int, np.int16, np.int32, np.int64)) + confidences = self._model.predict(inputs) + if self._targeted: + confi_choose = confidences[:, labels] + else: + confi_choose = np.max(confidences, axis=1) + return confi_choose - def _reduction(self, x_ori, q_times, label, best_position): + def _mutation_op(self, cur_pop): """ - Decrease the differences between the original samples and adversarial samples. + Generate mutation samples. Args: - x_ori (numpy.ndarray): Original samples. - q_times (int): Query times. - label (int): Target label ot ground-truth label. - best_position (numpy.ndarray): Adversarial examples. + cur_pop (numpy.ndarray): Inputs before mutation operation. Returns: - numpy.ndarray, adversarial examples after reduction. - - Examples: - >>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1, - >>> [0.12, 0.15, 0.25]) + numpy.ndarray, mutational inputs. """ - x_ori = check_numpy_param('x_ori', x_ori) - best_position = check_numpy_param('best_position', best_position) - x_ori, best_position = check_equal_shape('x_ori', x_ori, - 'best_position', best_position) - x_ori_fla = x_ori.flatten() - best_position_fla = best_position.flatten() + # LOGGER.info(TAG, 'Mutation happens...') pixel_deep = self._bounds[1] - self._bounds[0] - nums_pixel = len(x_ori_fla) - for i in range(nums_pixel): - diff = x_ori_fla[i] - best_position_fla[i] - if abs(diff) > pixel_deep*0.1: - old_poi_fla = np.copy(best_position_fla) - best_position_fla[i] = np.clip( - best_position_fla[i] + diff*0.5, - self._bounds[0], self._bounds[1]) - cur_label = np.argmax( - self._model.predict(np.expand_dims( - best_position_fla.reshape(x_ori.shape), axis=0))[0]) - q_times += 1 - if self._targeted: - if cur_label != label: - best_position_fla = old_poi_fla - else: - if cur_label == label: - best_position_fla = old_poi_fla - return best_position_fla.reshape(x_ori.shape), q_times + cur_pop = check_numpy_param('cur_pop', cur_pop) + perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep + mutated_pop = perturb_noise + cur_pop + if self._model_type == 'classification': + mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), + cur_pop + self._per_bounds*np.abs(cur_pop)), + self._bounds[0], self._bounds[1]) + return mutated_pop def generate(self, inputs, labels): """ @@ -173,8 +174,10 @@ class PSOAttack(Attack): labels (or ground_truth labels). Args: - inputs (numpy.ndarray): Input samples. - labels (numpy.ndarray): Targeted labels or ground_truth labels. + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) + or only one array if model_type='detection'. + labels (Union[numpy.ndarray, tuple]): Targeted labels or ground_truth labels. The format of labels + should be (gt_boxes, gt_labels) if model_type='detection'. Returns: - numpy.ndarray, bool values for each attack result. @@ -187,42 +190,51 @@ class PSOAttack(Attack): >>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], >>> [1, 2]) """ - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - if not self._sparse: - labels = np.argmax(labels, axis=1) + # inputs check + if self._model_type == 'classification': + inputs, labels = check_pair_numpy_param('inputs', inputs, + 'labels', labels) + if not self._sparse: + labels = np.argmax(labels, axis=1) + images = inputs + elif self._model_type == 'detection': + images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) + # generate one adversarial each time - if self._targeted: - target_labels = labels adv_list = [] success_list = [] query_times_list = [] - pixel_deep = self._bounds[1] - self._bounds[0] - for i in range(inputs.shape[0]): + for i in range(images.shape[0]): is_success = False q_times = 0 - x_ori = inputs[i] - confidences = self._model.predict(np.expand_dims(x_ori, axis=0))[0] + x_ori = images[i] + if not self._bounds: + self._bounds = [np.min(x_ori), np.max(x_ori)] + pixel_deep = self._bounds[1] - self._bounds[0] + q_times += 1 - true_label = labels[i] - if self._targeted: - t_label = target_labels[i] - confi_ori = confidences[t_label] - else: - confi_ori = max(confidences) + if self._model_type == 'classification': + label_i = labels[i] + confi_ori = self._confidence_cla(x_ori, label_i) + elif self._model_type == 'detection': + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[i], axis=0),) + gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i] + inputs_i = (images[i],) + auxiliary_input_i + confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) + LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) + # step1, initializing - # initial global optimum fitness value, cannot set to be 0 + # initial global optimum fitness value, cannot set to be -inf best_fitness = -np.inf # initial global optimum position best_position = x_ori x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) - cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5) - *self._step_size, - (0 - self._per_bounds)*(x_copies + 0.1), - self._per_bounds*(x_copies + 0.1)) - par = np.clip(x_copies + cur_noise, - x_copies*(1 - self._per_bounds), - x_copies*(1 + self._per_bounds)) + cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size, + (0 - self._per_bounds)*(np.abs(x_copies) + 0.1), + self._per_bounds*(np.abs(x_copies) + 0.1)) + par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) # initial advs par_ori = np.copy(par) # initial optimum positions for particles @@ -241,17 +253,17 @@ class PSOAttack(Attack): v_particles = self._step_size*( v_particles + self._c1*ran_1*(best_position - par)) \ + self._c2*ran_2*(par_best_poi - par) - par = np.clip(par + v_particles, - (par_ori + 0.1*pixel_deep)*( - 1 - self._per_bounds), - (par_ori + 0.1*pixel_deep)*( - 1 + self._per_bounds)) - if iters > 30 and is_mutation: + par = np.clip(np.clip(par + v_particles, + par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds, + par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds), + self._bounds[0], self._bounds[1]) + if iters > 20 and is_mutation: par = self._mutation_op(par) - if self._targeted: - confi_adv = self._model.predict(par)[:, t_label] - else: - confi_adv = np.max(self._model.predict(par), axis=1) + if self._model_type == 'classification': + confi_adv = self._confidence_cla(par, label_i) + elif self._model_type == 'detection': + confi_adv, _ = self._detection_scores( + (par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) q_times += self._pop_size fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) for k in range(self._pop_size): @@ -262,30 +274,36 @@ class PSOAttack(Attack): best_fitness = fit_value[k] best_position = par[k] iters += 1 - cur_pre = self._model.predict(np.expand_dims(best_position, - axis=0))[0] is_mutation = False if (best_fitness - last_best_fit) < last_best_fit*0.05: is_mutation = True - cur_label = np.argmax(cur_pre) + q_times += 1 - if self._targeted: - if cur_label == t_label: + if self._model_type == 'classification': + cur_pre = self._model.predict(best_position) + cur_label = np.argmax(cur_pre) + if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): is_success = True - else: - if cur_label != true_label: + elif self._model_type == 'detection': + _, correct_nums_adv = self._detection_scores( + (best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) + LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', + correct_nums_adv[0]) + if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): is_success = True + if is_success: LOGGER.debug(TAG, 'successfully find one adversarial ' 'sample and start Reduction process') # step3, reduction - if self._targeted: - best_position, q_times = self._reduction( - x_ori, q_times, t_label, best_position) - else: - best_position, q_times = self._reduction( - x_ori, q_times, true_label, best_position) + if self._model_type == 'classification': + best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, + targeted_attack=self._targeted) + break + if self._model_type == 'detection': + best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) if not is_success: LOGGER.debug(TAG, 'fail to find adversarial sample, iteration ' diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 8984b47..949fa51 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -267,3 +267,61 @@ def normalize_value(value, norm_level): LOGGER.error(TAG, msg) raise NotImplementedError(msg) return norm_value.reshape(ori_shape) + + +def check_detection_inputs(inputs, labels): + """ + Check the inputs for detection model attacks. + + Args: + inputs (Union[numpy.ndarray, tuple]): Images and other auxiliary inputs for detection model. + labels (tuple): Ground-truth boxes and ground-truth labels of inputs. + + Returns: + - numpy.ndarray, images data. + + - tuple, auxiliary inputs, such as image shape. + + - numpy.ndarray, ground-truth boxes. + + - numpy.ndarray, ground-truth labels. + """ + if isinstance(inputs, tuple): + has_images = False + auxiliary_inputs = tuple() + for item in inputs: + check_numpy_param('item', item) + if len(item.shape) == 4: + images = item + has_images = True + else: + auxiliary_inputs += (item,) + if not has_images: + msg = 'Inputs should contain images whose dimension is 4.' + LOGGER.error(TAG, msg) + raise ValueError(msg) + else: + check_numpy_param('inputs', inputs) + + check_param_type('labels', labels, tuple) + if len(labels) != 2: + msg = 'Labels should contain two arrays (boxes-confidences array and ground-truth labels array), ' \ + 'but got {} arrays.'.format(len(labels)) + LOGGER.error(TAG, msg) + raise ValueError(msg) + has_boxes = False + has_labels = False + for item in labels: + check_numpy_param('item', item) + if len(item.shape) == 3 and item.shape[2] == 5: + gt_boxes = item + has_boxes = True + elif len(item.shape) == 2: + gt_labels = item + has_labels = True + if (not has_boxes) or (not has_labels): + msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), respectively. ' \ + 'But got {} and {}.'.format(labels[0].shape, labels[1].shape) + LOGGER.error(TAG, msg) + raise ValueError(msg) + return images, auxiliary_inputs, gt_boxes, gt_labels diff --git a/mindarmour/utils/util.py b/mindarmour/utils/util.py index b5b4b97..64c7f18 100644 --- a/mindarmour/utils/util.py +++ b/mindarmour/utils/util.py @@ -17,6 +17,8 @@ from mindspore import Tensor from mindspore.nn import Cell from mindspore.ops.composite import GradOperation +from mindarmour.utils._check_param import check_numpy_param + from .logger import LogUtil LOGGER = LogUtil.get_instance() @@ -164,3 +166,36 @@ class GradWrap(Cell): """ gout = self.grad(self.network)(inputs, weight) return gout + + +def calculate_iou(box_i, box_j): + """ + Calculate the intersection over union (iou) of two boxes. + + Args: + box_i (numpy.ndarray): Coordinates of the first box, with the format as (x1, y1, x2, y2). + (x1, y1) and (x2, y2) are coordinates of the lower left corner and the upper right corner, + respectively. + box_j: (numpy.ndarray): Coordinates of the second box, with the format as (x1, y1, x2, y2). + + Returns: + float, iou of two input boxes. + """ + check_numpy_param('box_i', box_i) + check_numpy_param('box_j', box_j) + if box_i.shape[-1] != 4 or box_j.shape[-1] != 4: + msg = 'The length of both coordinate arrays should be 4, bug got {} and {}.'.format(box_i.shape, box_j.shape) + LOGGER.error(TAG, msg) + raise ValueError(msg) + i_x1, i_y1, i_x2, i_y2 = box_i + j_x1, j_y1, j_x2, j_y2 = box_j + s_i = (i_x2 - i_x1)*(i_y2 - i_y1) + s_j = (j_x2 - j_x1)*(j_y2 - j_y1) + inner_left_line = max(i_x1, j_x1) + inner_right_line = min(i_x2, j_x2) + inner_top_line = min(i_y2, j_y2) + inner_bottom_line = max(i_y1, j_y1) + if inner_left_line >= inner_right_line or inner_top_line <= inner_bottom_line: + return 0 + inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) + return inner_area / (s_i + s_j - inner_area) diff --git a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py index aa27cb5..141785d 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py @@ -36,6 +36,9 @@ class ModelToBeAttacked(BlackModel): def predict(self, inputs): """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(inputs.shape) == 1: + inputs = np.expand_dims(inputs, axis=0) result = self._network(Tensor(inputs.astype(np.float32))) return result.asnumpy()