@@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel): | |||
def predict(self, inputs): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(inputs.shape) == 3: | |||
inputs = np.expand_dims(inputs, axis=0) | |||
result = self._network(Tensor(inputs.astype(np.float32))) | |||
return result.asnumpy() | |||
@@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel): | |||
def predict(self, inputs): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(inputs.shape) == 3: | |||
inputs = np.expand_dims(inputs, axis=0) | |||
result = self._network(Tensor(inputs.astype(np.float32))) | |||
return result.asnumpy() | |||
@@ -0,0 +1,150 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""PSO attack for Faster R-CNN""" | |||
import os | |||
import numpy as np | |||
from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.common import set_seed | |||
from mindspore import Tensor | |||
from mindarmour import BlackModel | |||
from mindarmour.adv_robustness.attacks.black.genetic_attack import GeneticAttack | |||
from mindarmour.utils.logger import LogUtil | |||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | |||
from src.config import config | |||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset | |||
# pylint: disable=locally-disabled, unused-argument, redefined-outer-name | |||
LOGGER = LogUtil.get_instance() | |||
LOGGER.set_level('INFO') | |||
set_seed(1) | |||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1) | |||
class ModelToBeAttacked(BlackModel): | |||
"""model to be attack""" | |||
def __init__(self, network): | |||
super(ModelToBeAttacked, self).__init__() | |||
self._network = network | |||
def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(images.shape) == 3: | |||
inputs_num = 1 | |||
images = np.expand_dims(images, axis=0) | |||
else: | |||
inputs_num = images.shape[0] | |||
box_and_confi = [] | |||
pred_labels = [] | |||
gt_number = np.sum(gt_num) | |||
for i in range(inputs_num): | |||
inputs_i = np.expand_dims(images[i], axis=0) | |||
output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas), | |||
Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num)) | |||
all_bbox = output[0] | |||
all_labels = output[1] | |||
all_mask = output[2] | |||
all_bbox_squee = np.squeeze(all_bbox.asnumpy()) | |||
all_labels_squee = np.squeeze(all_labels.asnumpy()) | |||
all_mask_squee = np.squeeze(all_mask.asnumpy()) | |||
all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] | |||
all_labels_tmp_mask = all_labels_squee[all_mask_squee] | |||
if all_bboxes_tmp_mask.shape[0] > gt_number + 1: | |||
inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) | |||
inds = inds[:gt_number+1] | |||
all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] | |||
all_labels_tmp_mask = all_labels_tmp_mask[inds] | |||
box_and_confi.append(all_bboxes_tmp_mask) | |||
pred_labels.append(all_labels_tmp_mask) | |||
return np.array(box_and_confi), np.array(pred_labels) | |||
if __name__ == '__main__': | |||
prefix = 'FasterRcnn_eval.mindrecord' | |||
mindrecord_dir = config.mindrecord_dir | |||
mindrecord_file = os.path.join(mindrecord_dir, prefix) | |||
pre_trained = '/ckpt_path' | |||
print("CHECKING MINDRECORD FILES ...") | |||
if not os.path.exists(mindrecord_file): | |||
if not os.path.isdir(mindrecord_dir): | |||
os.makedirs(mindrecord_dir) | |||
if os.path.isdir(config.coco_root): | |||
print("Create Mindrecord. It may take some time.") | |||
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) | |||
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
else: | |||
print("coco_root not exits.") | |||
print('Start generate adversarial samples.') | |||
# build network and dataset | |||
ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \ | |||
repeat_num=1, is_training=False) | |||
net = Faster_Rcnn_Resnet50(config) | |||
param_dict = load_checkpoint(pre_trained) | |||
load_param_into_net(net, param_dict) | |||
net = net.set_train(False) | |||
# build attacker | |||
model = ModelToBeAttacked(net) | |||
attack = GeneticAttack(model, model_type='detection', max_steps=50, reserve_ratio=0.3, mutation_rate=0.05, | |||
per_bounds=0.5, step_size=0.25, temp=0.1) | |||
# generate adversarial samples | |||
sample_num = 5 | |||
ori_imagess = [] | |||
adv_imgs = [] | |||
ori_meta = [] | |||
ori_box = [] | |||
ori_labels = [] | |||
ori_gt_num = [] | |||
idx = 0 | |||
for data in ds.create_dict_iterator(): | |||
if idx > sample_num: | |||
break | |||
img_data = data['image'] | |||
img_metas = data['image_shape'] | |||
gt_bboxes = data['box'] | |||
gt_labels = data['label'] | |||
gt_num = data['valid_num'] | |||
ori_imagess.append(img_data.asnumpy()) | |||
ori_meta.append(img_metas.asnumpy()) | |||
ori_box.append(gt_bboxes.asnumpy()) | |||
ori_labels.append(gt_labels.asnumpy()) | |||
ori_gt_num.append(gt_num.asnumpy()) | |||
all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(), | |||
gt_labels.asnumpy(), gt_num.asnumpy()) | |||
pre_gt_boxes, pre_gt_label = model.predict(*all_inputs) | |||
success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label)) | |||
adv_imgs.append(adv_img) | |||
idx += 1 | |||
np.save('ori_imagess.npy', ori_imagess) | |||
np.save('ori_meta.npy', ori_meta) | |||
np.save('ori_box.npy', ori_box) | |||
np.save('ori_labels.npy', ori_labels) | |||
np.save('ori_gt_num.npy', ori_gt_num) | |||
np.save('adv_imgs.npy', adv_imgs) | |||
print('Generate adversarial samples complete.') |
@@ -0,0 +1,149 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""PSO attack for Faster R-CNN""" | |||
import os | |||
import numpy as np | |||
from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.common import set_seed | |||
from mindspore import Tensor | |||
from mindarmour import BlackModel | |||
from mindarmour.adv_robustness.attacks.black.pso_attack import PSOAttack | |||
from mindarmour.utils.logger import LogUtil | |||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | |||
from src.config import config | |||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset | |||
# pylint: disable=locally-disabled, unused-argument, redefined-outer-name | |||
LOGGER = LogUtil.get_instance() | |||
LOGGER.set_level('INFO') | |||
set_seed(1) | |||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1) | |||
class ModelToBeAttacked(BlackModel): | |||
"""model to be attack""" | |||
def __init__(self, network): | |||
super(ModelToBeAttacked, self).__init__() | |||
self._network = network | |||
def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(images.shape) == 3: | |||
inputs_num = 1 | |||
images = np.expand_dims(images, axis=0) | |||
else: | |||
inputs_num = images.shape[0] | |||
box_and_confi = [] | |||
pred_labels = [] | |||
gt_number = np.sum(gt_num) | |||
for i in range(inputs_num): | |||
inputs_i = np.expand_dims(images[i], axis=0) | |||
output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas), | |||
Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num)) | |||
all_bbox = output[0] | |||
all_labels = output[1] | |||
all_mask = output[2] | |||
all_bbox_squee = np.squeeze(all_bbox.asnumpy()) | |||
all_labels_squee = np.squeeze(all_labels.asnumpy()) | |||
all_mask_squee = np.squeeze(all_mask.asnumpy()) | |||
all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :] | |||
all_labels_tmp_mask = all_labels_squee[all_mask_squee] | |||
if all_bboxes_tmp_mask.shape[0] > gt_number + 1: | |||
inds = np.argsort(-all_bboxes_tmp_mask[:, -1]) | |||
inds = inds[:gt_number+1] | |||
all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds] | |||
all_labels_tmp_mask = all_labels_tmp_mask[inds] | |||
box_and_confi.append(all_bboxes_tmp_mask) | |||
pred_labels.append(all_labels_tmp_mask) | |||
return np.array(box_and_confi), np.array(pred_labels) | |||
if __name__ == '__main__': | |||
prefix = 'FasterRcnn_eval.mindrecord' | |||
mindrecord_dir = config.mindrecord_dir | |||
mindrecord_file = os.path.join(mindrecord_dir, prefix) | |||
pre_trained = '/ckpt_path' | |||
print("CHECKING MINDRECORD FILES ...") | |||
if not os.path.exists(mindrecord_file): | |||
if not os.path.isdir(mindrecord_dir): | |||
os.makedirs(mindrecord_dir) | |||
if os.path.isdir(config.coco_root): | |||
print("Create Mindrecord. It may take some time.") | |||
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1) | |||
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
else: | |||
print("coco_root not exits.") | |||
print('Start generate adversarial samples.') | |||
# build network and dataset | |||
ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, \ | |||
repeat_num=1, is_training=False) | |||
net = Faster_Rcnn_Resnet50(config) | |||
param_dict = load_checkpoint(pre_trained) | |||
load_param_into_net(net, param_dict) | |||
net = net.set_train(False) | |||
# build attacker | |||
model = ModelToBeAttacked(net) | |||
attack = PSOAttack(model, c=0.2, t_max=50, pm=0.5, model_type='detection', reserve_ratio=0.3) | |||
# generate adversarial samples | |||
sample_num = 5 | |||
ori_imagess = [] | |||
adv_imgs = [] | |||
ori_meta = [] | |||
ori_box = [] | |||
ori_labels = [] | |||
ori_gt_num = [] | |||
idx = 0 | |||
for data in ds.create_dict_iterator(): | |||
if idx > sample_num: | |||
break | |||
img_data = data['image'] | |||
img_metas = data['image_shape'] | |||
gt_bboxes = data['box'] | |||
gt_labels = data['label'] | |||
gt_num = data['valid_num'] | |||
ori_imagess.append(img_data.asnumpy()) | |||
ori_meta.append(img_metas.asnumpy()) | |||
ori_box.append(gt_bboxes.asnumpy()) | |||
ori_labels.append(gt_labels.asnumpy()) | |||
ori_gt_num.append(gt_num.asnumpy()) | |||
all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(), | |||
gt_labels.asnumpy(), gt_num.asnumpy()) | |||
pre_gt_boxes, pre_gt_label = model.predict(*all_inputs) | |||
success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label)) | |||
adv_imgs.append(adv_img) | |||
idx += 1 | |||
np.save('ori_imagess.npy', ori_imagess) | |||
np.save('ori_meta.npy', ori_meta) | |||
np.save('ori_box.npy', ori_box) | |||
np.save('ori_labels.npy', ori_labels) | |||
np.save('ori_gt_num.npy', ori_gt_num) | |||
np.save('adv_imgs.npy', adv_imgs) | |||
print('Generate adversarial samples complete.') |
@@ -19,8 +19,10 @@ from abc import abstractmethod | |||
import numpy as np | |||
from mindarmour.utils._check_param import check_pair_numpy_param, \ | |||
check_int_positive | |||
check_int_positive, check_equal_shape, check_numpy_param, check_model | |||
from mindarmour.utils.util import calculate_iou | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.adv_robustness.attacks.black.black_model import BlackModel | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Attack' | |||
@@ -87,7 +89,6 @@ class Attack: | |||
# Black-attack methods will return 3 values, just get the second. | |||
res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x) | |||
adv_x = np.concatenate(res, axis=0) | |||
return adv_x | |||
@@ -109,3 +110,125 @@ class Attack: | |||
'`Attack` and should be implemented in child class.' | |||
LOGGER.error(TAG, msg) | |||
raise NotImplementedError(msg) | |||
@staticmethod | |||
def _reduction(x_ori, q_times, label, best_position, model, targeted_attack): | |||
""" | |||
Decrease the differences between the original samples and adversarial samples. | |||
Args: | |||
x_ori (numpy.ndarray): Original samples. | |||
q_times (int): Query times. | |||
label (int): Target label ot ground-truth label. | |||
best_position (numpy.ndarray): Adversarial examples. | |||
model (BlackModel): Target model. | |||
targeted_attack (bool): If True, it means this is a targeted attack. If False, | |||
it means this is an untargeted attack. | |||
Returns: | |||
numpy.ndarray, adversarial examples after reduction. | |||
Examples: | |||
>>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1, | |||
>>> [0.12, 0.15, 0.25]) | |||
""" | |||
LOGGER.info(TAG, 'Reduction begins...') | |||
model = check_model('model', model, BlackModel) | |||
x_ori = check_numpy_param('x_ori', x_ori) | |||
best_position = check_numpy_param('best_position', best_position) | |||
x_ori, best_position = check_equal_shape('x_ori', x_ori, | |||
'best_position', best_position) | |||
x_ori_fla = x_ori.flatten() | |||
best_position_fla = best_position.flatten() | |||
pixel_deep = np.max(x_ori) - np.min(x_ori) | |||
nums_pixel = len(x_ori_fla) | |||
for i in range(nums_pixel): | |||
diff = x_ori_fla[i] - best_position_fla[i] | |||
if abs(diff) > pixel_deep*0.1: | |||
best_position_fla[i] += diff*0.5 | |||
cur_label = np.argmax( | |||
model.predict(best_position_fla.reshape(x_ori.shape))) | |||
q_times += 1 | |||
if targeted_attack: | |||
if cur_label != label: | |||
best_position_fla[i] -= diff * 0.5 | |||
else: | |||
if cur_label == label: | |||
best_position_fla -= diff*0.5 | |||
return best_position_fla.reshape(x_ori.shape), q_times | |||
def _fast_reduction(self, x_ori, best_position, q_times, auxiliary_inputs, gt_boxes, gt_labels, model): | |||
""" | |||
Decrease the differences between the original samples and adversarial samples in a fast way. | |||
Args: | |||
x_ori (numpy.ndarray): Original samples. | |||
best_position (numpy.ndarray): Adversarial examples. | |||
q_times (int): Query times. | |||
auxiliary_inputs (tuple): Auxiliary inputs mathced with x_ori. | |||
gt_boxes (numpy.ndarray): Ground-truth boxes of x_ori. | |||
gt_labels (numpy.ndarray): Ground-truth labels of x_ori. | |||
model (BlackModel): Target model. | |||
Returns: | |||
- numpy.ndarray, adversarial examples after reduction. | |||
- int, total query times after reduction. | |||
""" | |||
LOGGER.info(TAG, 'Reduction begins...') | |||
model = check_model('model', model, BlackModel) | |||
x_ori = check_numpy_param('x_ori', x_ori) | |||
best_position = check_numpy_param('best_position', best_position) | |||
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | |||
x_shape = best_position.shape | |||
reduction_iters = 10000 # recover 0.01% each step | |||
_, original_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
for _ in range(reduction_iters): | |||
diff = x_ori - best_position | |||
res = 0.5*diff*(np.random.random(x_shape) < 0.0001) | |||
best_position += res | |||
_, correct_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
q_times += 1 | |||
if correct_num > original_num: | |||
best_position -= res | |||
return best_position, q_times | |||
@staticmethod | |||
def _detection_scores(inputs, gt_boxes, gt_labels, model): | |||
""" | |||
Evaluate the detection result of inputs, specially for object detection models. | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
gt_boxes (numpy.ndarray): Ground-truth boxes of inputs. | |||
gt_labels (numpy.ndarray): Ground-truth labels of inputs. | |||
model (BlackModel): Target model. | |||
Returns: | |||
- numpy.ndarray, detection scores of inputs. | |||
- numpy.ndarray, the number of objects that are correctly detected. | |||
""" | |||
model = check_model('model', model, BlackModel) | |||
box_and_confi, pred_labels = model.predict(*inputs) | |||
det_scores = [] | |||
correct_labels_num = [] | |||
gt_boxes_num = gt_boxes.shape[0] | |||
iou_thres = 0.5 | |||
for boxes, labels in zip(box_and_confi, pred_labels): | |||
score = 0 | |||
box_num = boxes.shape[0] | |||
correct_label_flag = np.zeros(gt_labels.shape) | |||
for i in range(box_num): | |||
pred_box = boxes[i] | |||
max_iou_confi = 0 | |||
for j in range(gt_boxes_num): | |||
iou = calculate_iou(pred_box[:4], gt_boxes[j][:4]) | |||
if labels[i] == gt_labels[j] and iou > iou_thres: | |||
max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) | |||
correct_label_flag[j] = 1 | |||
score += max_iou_confi | |||
det_scores.append(score) | |||
correct_labels_num.append(np.sum(correct_label_flag)) | |||
return np.array(det_scores), np.array(correct_labels_num) |
@@ -20,38 +20,14 @@ from scipy.special import softmax | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils._check_param import check_numpy_param, check_model, \ | |||
check_pair_numpy_param, check_param_type, check_value_positive, \ | |||
check_int_positive, check_param_multi_types | |||
from ..attack import Attack | |||
check_int_positive, check_detection_inputs, check_value_non_negative, check_param_multi_types | |||
from mindarmour.adv_robustness.attacks.attack import Attack | |||
from .black_model import BlackModel | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'GeneticAttack' | |||
def _mutation(cur_pop, step_noise=0.01, prob=0.005): | |||
""" | |||
Generate mutation samples in genetic_attack. | |||
Args: | |||
cur_pop (numpy.ndarray): Samples before mutation. | |||
step_noise (float): Noise range. Default: 0.01. | |||
prob (float): Mutation probability. Default: 0.005. | |||
Returns: | |||
numpy.ndarray, samples after mutation operation in genetic_attack. | |||
Examples: | |||
>>> mul_pop = self._mutation_op([0.2, 0.3, 0.4], step_noise=0.03, | |||
>>> prob=0.01) | |||
""" | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5, | |||
-step_noise, step_noise) | |||
mutated_pop = perturb_noise*( | |||
np.random.random(cur_pop.shape) < prob) + cur_pop | |||
return mutated_pop | |||
class GeneticAttack(Attack): | |||
""" | |||
The Genetic Attack represents the black-box attack based on the genetic algorithm, | |||
@@ -65,6 +41,13 @@ class GeneticAttack(Attack): | |||
Args: | |||
model (BlackModel): Target model. | |||
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | |||
default: 'classification'. | |||
targeted (bool): If True, turns on the targeted attack. If False, | |||
turns on untargeted attack. It should be noted that only untargeted attack | |||
is supproted for model_type='detection', Default: False. | |||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||
model_type='detection'. Default: 0.3. | |||
pop_size (int): The number of particles, which should be greater than | |||
zero. Default: 6. | |||
mutation_rate (float): The probability of mutations. Default: 0.005. | |||
@@ -73,23 +56,33 @@ class GeneticAttack(Attack): | |||
example. Default: 1000. | |||
step_size (float): Attack step size. Default: 0.2. | |||
temp (float): Sampling temperature for selection. Default: 0.3. | |||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min, | |||
clip_max). Default: (0, 1.0) | |||
The greater the temp, the greater the differences between individuals' | |||
selecting probabilities. | |||
bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form | |||
of (clip_min, clip_max). Default: (0, 1.0). | |||
adaptive (bool): If True, turns on dynamic scaling of mutation | |||
parameters. If false, turns on static mutation parameters. | |||
Default: False. | |||
sparse (bool): If True, input labels are sparse-encoded. If False, | |||
input labels are one-hot-encoded. Default: True. | |||
c (float): Weight of perturbation loss. Default: 0.1. | |||
Examples: | |||
>>> attack = GeneticAttack(model) | |||
""" | |||
def __init__(self, model, pop_size=6, | |||
mutation_rate=0.005, per_bounds=0.15, max_steps=1000, | |||
step_size=0.20, temp=0.3, bounds=(0, 1.0), adaptive=False, | |||
sparse=True): | |||
def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True, | |||
pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3, | |||
bounds=(0, 1.0), adaptive=False, c=0.1): | |||
super(GeneticAttack, self).__init__() | |||
self._model = check_model('model', model, BlackModel) | |||
self._model_type = check_param_type('model_type', model_type, str) | |||
self._targeted = check_param_type('targeted', targeted, bool) | |||
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | |||
if self._reserve_ratio > 1: | |||
msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._sparse = check_param_type('sparse', sparse, bool) | |||
self._per_bounds = check_value_positive('per_bounds', per_bounds) | |||
self._pop_size = check_int_positive('pop_size', pop_size) | |||
self._step_size = check_value_positive('step_size', step_size) | |||
@@ -98,16 +91,41 @@ class GeneticAttack(Attack): | |||
self._mutation_rate = check_value_positive('mutation_rate', | |||
mutation_rate) | |||
self._adaptive = check_param_type('adaptive', adaptive, bool) | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
# initial global optimum fitness value | |||
self._best_fit = -1 | |||
self._best_fit = -np.inf | |||
# count times of no progress | |||
self._plateau_times = 0 | |||
# count times of changing attack step | |||
# count times of changing attack step_size | |||
self._adap_times = 0 | |||
self._sparse = check_param_type('sparse', sparse, bool) | |||
self._bounds = bounds | |||
if self._bounds is not None: | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
self._c = check_value_positive('c', c) | |||
def _mutation(self, cur_pop, step_noise=0.01, prob=0.005): | |||
""" | |||
Generate mutation samples in genetic_attack. | |||
Args: | |||
cur_pop (numpy.ndarray): Samples before mutation. | |||
step_noise (float): Noise range. Default: 0.01. | |||
prob (float): Mutation probability. Default: 0.005. | |||
Returns: | |||
numpy.ndarray, samples after mutation operation in genetic_attack. | |||
Examples: | |||
>>> mul_pop = self._mutation_op([0.2, 0.3, 0.4], step_noise=0.03, | |||
>>> prob=0.01) | |||
""" | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5, | |||
-step_noise, step_noise)*(self._bounds[1] - self._bounds[0]) | |||
mutated_pop = perturb_noise*( | |||
np.random.random(cur_pop.shape) < prob) + cur_pop | |||
return mutated_pop | |||
def generate(self, inputs, labels): | |||
""" | |||
@@ -115,8 +133,10 @@ class GeneticAttack(Attack): | |||
labels (or ground_truth labels). | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
labels (numpy.ndarray): Targeted labels. | |||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) | |||
or only one array if model_type='detection'. | |||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels | |||
should be (gt_boxes, gt_labels) if model_type='detection'. | |||
Returns: | |||
- numpy.ndarray, bool values for each attack result. | |||
@@ -130,47 +150,96 @@ class GeneticAttack(Attack): | |||
>>> [0.3, 0.3, 0.2]], | |||
>>> [1, 2]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
# if input is one-hot encoded, get sparse format value | |||
if not self._sparse: | |||
if labels.ndim != 2: | |||
raise ValueError('labels must be 2 dims, ' | |||
'but got {} dims.'.format(labels.ndim)) | |||
labels = np.argmax(labels, axis=1) | |||
if self._model_type == 'classification': | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if not self._sparse: | |||
if labels.ndim != 2: | |||
raise ValueError('labels must be 2 dims, ' | |||
'but got {} dims.'.format(labels.ndim)) | |||
labels = np.argmax(labels, axis=1) | |||
images = inputs | |||
elif self._model_type == 'detection': | |||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||
adv_list = [] | |||
success_list = [] | |||
query_times_list = [] | |||
for i in range(inputs.shape[0]): | |||
for i in range(images.shape[0]): | |||
is_success = False | |||
target_label = labels[i] | |||
iters = 0 | |||
x_ori = inputs[i] | |||
x_ori = images[i] | |||
if not self._bounds: | |||
self._bounds = [np.min(x_ori), np.max(x_ori)] | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
if self._model_type == 'classification': | |||
label_i = labels[i] | |||
elif self._model_type == 'detection': | |||
auxiliary_input_i = tuple() | |||
for item in auxiliary_inputs: | |||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||
gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i] | |||
inputs_i = (images[i],) + auxiliary_input_i | |||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model) | |||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||
# generate particles | |||
ori_copies = np.repeat( | |||
x_ori[np.newaxis, :], self._pop_size, axis=0) | |||
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||
# initial perturbations | |||
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size, | |||
(0 - self._per_bounds), | |||
self._per_bounds) | |||
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep, | |||
(0 - self._per_bounds)*pixel_deep, | |||
self._per_bounds*pixel_deep) | |||
query_times = 0 | |||
iters = 0 | |||
while iters < self._max_steps: | |||
iters += 1 | |||
cur_pop = np.clip( | |||
ori_copies + cur_pert, self._bounds[0], self._bounds[1]) | |||
pop_preds = self._model.predict(cur_pop) | |||
query_times += cur_pop.shape[0] | |||
all_preds = np.argmax(pop_preds, axis=1) | |||
success_pop = np.equal(target_label, all_preds).astype(np.int32) | |||
success = max(success_pop) | |||
if success == 1: | |||
is_success = True | |||
adv = cur_pop[np.argmax(success_pop)] | |||
if self._model_type == 'classification': | |||
pop_preds = self._model.predict(cur_pop) | |||
query_times += cur_pop.shape[0] | |||
all_preds = np.argmax(pop_preds, axis=1) | |||
if self._targeted: | |||
success_pop = np.equal(label_i, all_preds).astype(np.int32) | |||
else: | |||
success_pop = np.not_equal(label_i, all_preds).astype(np.int32) | |||
is_success = max(success_pop) | |||
best_idx = np.argmax(success_pop) | |||
target_preds = pop_preds[:, label_i] | |||
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds | |||
if self._targeted: | |||
fit_vals = target_preds - others_preds_sum | |||
else: | |||
fit_vals = others_preds_sum - target_preds | |||
elif self._model_type == 'detection': | |||
confi_adv, correct_nums_adv = self._detection_scores( | |||
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||
np.min(correct_nums_adv)) | |||
query_times += self._pop_size | |||
fit_vals = abs( | |||
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( | |||
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) | |||
if np.max(fit_vals) < 0: | |||
self._c /= 2 | |||
if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): | |||
is_success = True | |||
best_idx = np.argmin(correct_nums_adv) | |||
if is_success: | |||
LOGGER.debug(TAG, 'successfully find one adversarial sample ' | |||
'and start Reduction process.') | |||
final_adv = cur_pop[best_idx] | |||
if self._model_type == 'classification': | |||
final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv, | |||
model=self._model, targeted_attack=self._targeted) | |||
break | |||
target_preds = pop_preds[:, target_label] | |||
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds | |||
fit_vals = target_preds - others_preds_sum | |||
best_fit = max(target_preds - np.max(pop_preds)) | |||
best_fit = max(fit_vals) | |||
if best_fit > self._best_fit: | |||
self._best_fit = best_fit | |||
self._plateau_times = 0 | |||
@@ -206,17 +275,19 @@ class GeneticAttack(Attack): | |||
cross_probs = (np.random.random(parent1.shape) > | |||
parent2_probs).astype(np.int32) | |||
childs = parent1*cross_probs + parent2*(1 - cross_probs) | |||
mutated_childs = _mutation( | |||
mutated_childs = self._mutation( | |||
childs, step_noise=self._per_bounds*step_noise, | |||
prob=step_p) | |||
cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :])) | |||
if is_success: | |||
LOGGER.debug(TAG, 'successfully find one adversarial sample ' | |||
'and start Reduction process.') | |||
adv_list.append(adv) | |||
else: | |||
if not is_success: | |||
LOGGER.debug(TAG, 'fail to find adversarial sample.') | |||
adv_list.append(elite + x_ori) | |||
final_adv = elite + x_ori | |||
if self._model_type == 'detection': | |||
final_adv, query_times = self._fast_reduction( | |||
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||
adv_list.append(final_adv) | |||
LOGGER.debug(TAG, | |||
'iteration times is: %d and query times is: %d', | |||
iters, | |||
@@ -19,7 +19,8 @@ import numpy as np | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \ | |||
check_numpy_param, check_value_positive, check_int_positive, \ | |||
check_param_type, check_equal_shape, check_param_multi_types | |||
check_param_type, check_param_multi_types,\ | |||
check_value_non_negative, check_detection_inputs | |||
from ..attack import Attack | |||
from .black_model import BlackModel | |||
@@ -53,18 +54,21 @@ class PSOAttack(Attack): | |||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min, | |||
clip_max). Default: None. | |||
targeted (bool): If True, turns on the targeted attack. If False, | |||
turns on untargeted attack. Default: False. | |||
reduction_iters (int): Cycle times in reduction process. Default: 3. | |||
turns on untargeted attack. It should be noted that only untargeted attack | |||
is supproted for model_type='detection', Default: False. | |||
sparse (bool): If True, input labels are sparse-encoded. If False, | |||
input labels are one-hot-encoded. Default: True. | |||
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | |||
default: 'classification'. | |||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||
model_type='detection'. Default: 0.3. | |||
Examples: | |||
>>> attack = PSOAttack(model) | |||
""" | |||
def __init__(self, model, step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0, | |||
c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None, | |||
targeted=False, reduction_iters=3, sparse=True): | |||
def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True, | |||
step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0, c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None): | |||
super(PSOAttack, self).__init__() | |||
self._model = check_model('model', model, BlackModel) | |||
self._step_size = check_value_positive('step_size', step_size) | |||
@@ -74,14 +78,24 @@ class PSOAttack(Attack): | |||
self._c = check_value_positive('c', c) | |||
self._pop_size = check_int_positive('pop_size', pop_size) | |||
self._pm = check_value_positive('pm', pm) | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
self._bounds = bounds | |||
if self._bounds is not None: | |||
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) | |||
for b in self._bounds: | |||
_ = check_param_multi_types('bound', b, [int, float]) | |||
self._targeted = check_param_type('targeted', targeted, bool) | |||
self._t_max = check_int_positive('t_max', t_max) | |||
self._reduce_iters = check_int_positive('reduction_iters', | |||
reduction_iters) | |||
self._sparse = check_param_type('sparse', sparse, bool) | |||
self._model_type = check_param_type('model_type', model_type, str) | |||
if self._model_type not in ('classification', 'detection'): | |||
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | |||
if self._reserve_ratio > 1: | |||
msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
def _fitness(self, confi_ori, confi_adv, x_ori, x_adv): | |||
""" | |||
@@ -109,63 +123,50 @@ class PSOAttack(Attack): | |||
fit_value = abs( | |||
confi_ori - confi_adv) - self._c / self._pop_size*np.linalg.norm( | |||
(x_adv - x_ori).reshape(x_adv.shape[0], -1), axis=1) | |||
if np.max(fit_value) < 0: | |||
self._c /= 2 | |||
return fit_value | |||
def _mutation_op(self, cur_pop): | |||
def _confidence_cla(self, inputs, labels): | |||
""" | |||
Generate mutation samples. | |||
Calculate the prediction confidence of corresponding label or max confidence of inputs. | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
labels (Union[numpy.int, numpy.int16, numpy.int32, numpy.int64]): Target labels. | |||
Returns: | |||
float, the prediction confidences of inputs. | |||
""" | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = np.random.random(cur_pop.shape) - 0.5 | |||
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) | |||
< self._pm) + cur_pop | |||
mutated_pop = np.clip(mutated_pop, cur_pop*(1 - self._per_bounds), | |||
cur_pop*(1 + self._per_bounds)) | |||
return mutated_pop | |||
check_numpy_param('inputs', inputs) | |||
check_param_multi_types('labels', labels, (np.int, np.int16, np.int32, np.int64)) | |||
confidences = self._model.predict(inputs) | |||
if self._targeted: | |||
confi_choose = confidences[:, labels] | |||
else: | |||
confi_choose = np.max(confidences, axis=1) | |||
return confi_choose | |||
def _reduction(self, x_ori, q_times, label, best_position): | |||
def _mutation_op(self, cur_pop): | |||
""" | |||
Decrease the differences between the original samples and adversarial samples. | |||
Generate mutation samples. | |||
Args: | |||
x_ori (numpy.ndarray): Original samples. | |||
q_times (int): Query times. | |||
label (int): Target label ot ground-truth label. | |||
best_position (numpy.ndarray): Adversarial examples. | |||
cur_pop (numpy.ndarray): Inputs before mutation operation. | |||
Returns: | |||
numpy.ndarray, adversarial examples after reduction. | |||
Examples: | |||
>>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1, | |||
>>> [0.12, 0.15, 0.25]) | |||
numpy.ndarray, mutational inputs. | |||
""" | |||
x_ori = check_numpy_param('x_ori', x_ori) | |||
best_position = check_numpy_param('best_position', best_position) | |||
x_ori, best_position = check_equal_shape('x_ori', x_ori, | |||
'best_position', best_position) | |||
x_ori_fla = x_ori.flatten() | |||
best_position_fla = best_position.flatten() | |||
# LOGGER.info(TAG, 'Mutation happens...') | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
nums_pixel = len(x_ori_fla) | |||
for i in range(nums_pixel): | |||
diff = x_ori_fla[i] - best_position_fla[i] | |||
if abs(diff) > pixel_deep*0.1: | |||
old_poi_fla = np.copy(best_position_fla) | |||
best_position_fla[i] = np.clip( | |||
best_position_fla[i] + diff*0.5, | |||
self._bounds[0], self._bounds[1]) | |||
cur_label = np.argmax( | |||
self._model.predict(np.expand_dims( | |||
best_position_fla.reshape(x_ori.shape), axis=0))[0]) | |||
q_times += 1 | |||
if self._targeted: | |||
if cur_label != label: | |||
best_position_fla = old_poi_fla | |||
else: | |||
if cur_label == label: | |||
best_position_fla = old_poi_fla | |||
return best_position_fla.reshape(x_ori.shape), q_times | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep | |||
mutated_pop = perturb_noise + cur_pop | |||
if self._model_type == 'classification': | |||
mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), | |||
cur_pop + self._per_bounds*np.abs(cur_pop)), | |||
self._bounds[0], self._bounds[1]) | |||
return mutated_pop | |||
def generate(self, inputs, labels): | |||
""" | |||
@@ -173,8 +174,10 @@ class PSOAttack(Attack): | |||
labels (or ground_truth labels). | |||
Args: | |||
inputs (numpy.ndarray): Input samples. | |||
labels (numpy.ndarray): Targeted labels or ground_truth labels. | |||
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...) | |||
or only one array if model_type='detection'. | |||
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground_truth labels. The format of labels | |||
should be (gt_boxes, gt_labels) if model_type='detection'. | |||
Returns: | |||
- numpy.ndarray, bool values for each attack result. | |||
@@ -187,42 +190,51 @@ class PSOAttack(Attack): | |||
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]], | |||
>>> [1, 2]) | |||
""" | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if not self._sparse: | |||
labels = np.argmax(labels, axis=1) | |||
# inputs check | |||
if self._model_type == 'classification': | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if not self._sparse: | |||
labels = np.argmax(labels, axis=1) | |||
images = inputs | |||
elif self._model_type == 'detection': | |||
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels) | |||
# generate one adversarial each time | |||
if self._targeted: | |||
target_labels = labels | |||
adv_list = [] | |||
success_list = [] | |||
query_times_list = [] | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
for i in range(inputs.shape[0]): | |||
for i in range(images.shape[0]): | |||
is_success = False | |||
q_times = 0 | |||
x_ori = inputs[i] | |||
confidences = self._model.predict(np.expand_dims(x_ori, axis=0))[0] | |||
x_ori = images[i] | |||
if not self._bounds: | |||
self._bounds = [np.min(x_ori), np.max(x_ori)] | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
q_times += 1 | |||
true_label = labels[i] | |||
if self._targeted: | |||
t_label = target_labels[i] | |||
confi_ori = confidences[t_label] | |||
else: | |||
confi_ori = max(confidences) | |||
if self._model_type == 'classification': | |||
label_i = labels[i] | |||
confi_ori = self._confidence_cla(x_ori, label_i) | |||
elif self._model_type == 'detection': | |||
auxiliary_input_i = tuple() | |||
for item in auxiliary_inputs: | |||
auxiliary_input_i += (np.expand_dims(item[i], axis=0),) | |||
gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i] | |||
inputs_i = (images[i],) + auxiliary_input_i | |||
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model) | |||
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0]) | |||
# step1, initializing | |||
# initial global optimum fitness value, cannot set to be 0 | |||
# initial global optimum fitness value, cannot set to be -inf | |||
best_fitness = -np.inf | |||
# initial global optimum position | |||
best_position = x_ori | |||
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5) | |||
*self._step_size, | |||
(0 - self._per_bounds)*(x_copies + 0.1), | |||
self._per_bounds*(x_copies + 0.1)) | |||
par = np.clip(x_copies + cur_noise, | |||
x_copies*(1 - self._per_bounds), | |||
x_copies*(1 + self._per_bounds)) | |||
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size, | |||
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1), | |||
self._per_bounds*(np.abs(x_copies) + 0.1)) | |||
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) | |||
# initial advs | |||
par_ori = np.copy(par) | |||
# initial optimum positions for particles | |||
@@ -241,17 +253,17 @@ class PSOAttack(Attack): | |||
v_particles = self._step_size*( | |||
v_particles + self._c1*ran_1*(best_position - par)) \ | |||
+ self._c2*ran_2*(par_best_poi - par) | |||
par = np.clip(par + v_particles, | |||
(par_ori + 0.1*pixel_deep)*( | |||
1 - self._per_bounds), | |||
(par_ori + 0.1*pixel_deep)*( | |||
1 + self._per_bounds)) | |||
if iters > 30 and is_mutation: | |||
par = np.clip(np.clip(par + v_particles, | |||
par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds, | |||
par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds), | |||
self._bounds[0], self._bounds[1]) | |||
if iters > 20 and is_mutation: | |||
par = self._mutation_op(par) | |||
if self._targeted: | |||
confi_adv = self._model.predict(par)[:, t_label] | |||
else: | |||
confi_adv = np.max(self._model.predict(par), axis=1) | |||
if self._model_type == 'classification': | |||
confi_adv = self._confidence_cla(par, label_i) | |||
elif self._model_type == 'detection': | |||
confi_adv, _ = self._detection_scores( | |||
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||
q_times += self._pop_size | |||
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par) | |||
for k in range(self._pop_size): | |||
@@ -262,30 +274,36 @@ class PSOAttack(Attack): | |||
best_fitness = fit_value[k] | |||
best_position = par[k] | |||
iters += 1 | |||
cur_pre = self._model.predict(np.expand_dims(best_position, | |||
axis=0))[0] | |||
is_mutation = False | |||
if (best_fitness - last_best_fit) < last_best_fit*0.05: | |||
is_mutation = True | |||
cur_label = np.argmax(cur_pre) | |||
q_times += 1 | |||
if self._targeted: | |||
if cur_label == t_label: | |||
if self._model_type == 'classification': | |||
cur_pre = self._model.predict(best_position) | |||
cur_label = np.argmax(cur_pre) | |||
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i): | |||
is_success = True | |||
else: | |||
if cur_label != true_label: | |||
elif self._model_type == 'detection': | |||
_, correct_nums_adv = self._detection_scores( | |||
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s', | |||
correct_nums_adv[0]) | |||
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio): | |||
is_success = True | |||
if is_success: | |||
LOGGER.debug(TAG, 'successfully find one adversarial ' | |||
'sample and start Reduction process') | |||
# step3, reduction | |||
if self._targeted: | |||
best_position, q_times = self._reduction( | |||
x_ori, q_times, t_label, best_position) | |||
else: | |||
best_position, q_times = self._reduction( | |||
x_ori, q_times, true_label, best_position) | |||
if self._model_type == 'classification': | |||
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model, | |||
targeted_attack=self._targeted) | |||
break | |||
if self._model_type == 'detection': | |||
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times, | |||
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model) | |||
if not is_success: | |||
LOGGER.debug(TAG, | |||
'fail to find adversarial sample, iteration ' | |||
@@ -267,3 +267,61 @@ def normalize_value(value, norm_level): | |||
LOGGER.error(TAG, msg) | |||
raise NotImplementedError(msg) | |||
return norm_value.reshape(ori_shape) | |||
def check_detection_inputs(inputs, labels): | |||
""" | |||
Check the inputs for detection model attacks. | |||
Args: | |||
inputs (Union[numpy.ndarray, tuple]): Images and other auxiliary inputs for detection model. | |||
labels (tuple): Ground-truth boxes and ground-truth labels of inputs. | |||
Returns: | |||
- numpy.ndarray, images data. | |||
- tuple, auxiliary inputs, such as image shape. | |||
- numpy.ndarray, ground-truth boxes. | |||
- numpy.ndarray, ground-truth labels. | |||
""" | |||
if isinstance(inputs, tuple): | |||
has_images = False | |||
auxiliary_inputs = tuple() | |||
for item in inputs: | |||
check_numpy_param('item', item) | |||
if len(item.shape) == 4: | |||
images = item | |||
has_images = True | |||
else: | |||
auxiliary_inputs += (item,) | |||
if not has_images: | |||
msg = 'Inputs should contain images whose dimension is 4.' | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
else: | |||
check_numpy_param('inputs', inputs) | |||
check_param_type('labels', labels, tuple) | |||
if len(labels) != 2: | |||
msg = 'Labels should contain two arrays (boxes-confidences array and ground-truth labels array), ' \ | |||
'but got {} arrays.'.format(len(labels)) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
has_boxes = False | |||
has_labels = False | |||
for item in labels: | |||
check_numpy_param('item', item) | |||
if len(item.shape) == 3 and item.shape[2] == 5: | |||
gt_boxes = item | |||
has_boxes = True | |||
elif len(item.shape) == 2: | |||
gt_labels = item | |||
has_labels = True | |||
if (not has_boxes) or (not has_labels): | |||
msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), respectively. ' \ | |||
'But got {} and {}.'.format(labels[0].shape, labels[1].shape) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
return images, auxiliary_inputs, gt_boxes, gt_labels |
@@ -17,6 +17,8 @@ from mindspore import Tensor | |||
from mindspore.nn import Cell | |||
from mindspore.ops.composite import GradOperation | |||
from mindarmour.utils._check_param import check_numpy_param | |||
from .logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
@@ -164,3 +166,36 @@ class GradWrap(Cell): | |||
""" | |||
gout = self.grad(self.network)(inputs, weight) | |||
return gout | |||
def calculate_iou(box_i, box_j): | |||
""" | |||
Calculate the intersection over union (iou) of two boxes. | |||
Args: | |||
box_i (numpy.ndarray): Coordinates of the first box, with the format as (x1, y1, x2, y2). | |||
(x1, y1) and (x2, y2) are coordinates of the lower left corner and the upper right corner, | |||
respectively. | |||
box_j: (numpy.ndarray): Coordinates of the second box, with the format as (x1, y1, x2, y2). | |||
Returns: | |||
float, iou of two input boxes. | |||
""" | |||
check_numpy_param('box_i', box_i) | |||
check_numpy_param('box_j', box_j) | |||
if box_i.shape[-1] != 4 or box_j.shape[-1] != 4: | |||
msg = 'The length of both coordinate arrays should be 4, bug got {} and {}.'.format(box_i.shape, box_j.shape) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
i_x1, i_y1, i_x2, i_y2 = box_i | |||
j_x1, j_y1, j_x2, j_y2 = box_j | |||
s_i = (i_x2 - i_x1)*(i_y2 - i_y1) | |||
s_j = (j_x2 - j_x1)*(j_y2 - j_y1) | |||
inner_left_line = max(i_x1, j_x1) | |||
inner_right_line = min(i_x2, j_x2) | |||
inner_top_line = min(i_y2, j_y2) | |||
inner_bottom_line = max(i_y1, j_y1) | |||
if inner_left_line >= inner_right_line or inner_top_line <= inner_bottom_line: | |||
return 0 | |||
inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) | |||
return inner_area / (s_i + s_j - inner_area) |
@@ -36,6 +36,9 @@ class ModelToBeAttacked(BlackModel): | |||
def predict(self, inputs): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(inputs.shape) == 1: | |||
inputs = np.expand_dims(inputs, axis=0) | |||
result = self._network(Tensor(inputs.astype(np.float32))) | |||
return result.asnumpy() | |||