
Extend PSOAttack and GeneticAttack to object detection models.

tags/v1.2.1
lvzhangcheng 4 years ago
parent
commit
99581e3700
10 changed files with 796 additions and 183 deletions
  1. +3
    -0
      examples/model_security/model_attacks/black_box/mnist_attack_genetic.py
  2. +3
    -0
      examples/model_security/model_attacks/black_box/mnist_attack_pso.py
  3. +150
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_genetic.py
  4. +149
    -0
      examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pso.py
  5. +125
    -2
      mindarmour/adv_robustness/attacks/attack.py
  6. +147
    -76
      mindarmour/adv_robustness/attacks/black/genetic_attack.py
  7. +123
    -105
      mindarmour/adv_robustness/attacks/black/pso_attack.py
  8. +58
    -0
      mindarmour/utils/_check_param.py
  9. +35
    -0
      mindarmour/utils/util.py
  10. +3
    -0
      tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py
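
For orientation before the per-file diffs, here is a minimal sketch of how the new detection mode is driven end to end. The RandomDetector stub and all shapes are illustrative assumptions standing in for a real Faster R-CNN wrapper; only the PSOAttack call reflects the API added in this commit.

import numpy as np
from mindarmour import BlackModel
from mindarmour.adv_robustness.attacks.black.pso_attack import PSOAttack

class RandomDetector(BlackModel):
    """Hypothetical stand-in for a detection network, for illustration only."""
    def predict(self, images, *auxiliary):
        num = images.shape[0] if len(images.shape) == 4 else 1
        boxes = np.random.random((num, 10, 5)).astype(np.float32)  # (x1, y1, x2, y2, confidence)
        labels = np.random.randint(0, 80, (num, 10))
        return boxes, labels

images = np.random.random((1, 3, 128, 128)).astype(np.float32)  # (N, C, H, W)
img_metas = np.array([[128, 128, 1.0, 1.0]], dtype=np.float32)  # auxiliary input
gt_boxes = np.random.random((1, 16, 5)).astype(np.float32)      # (N, M, 5)
gt_labels = np.random.randint(0, 80, (1, 16))                   # (N, M)

attack = PSOAttack(RandomDetector(), model_type='detection', reserve_ratio=0.3, t_max=5)
flags, adv_imgs, queries = attack.generate((images, img_metas, gt_boxes, gt_labels),
                                           (gt_boxes, gt_labels))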

+ 3
- 0
examples/model_security/model_attacks/black_box/mnist_attack_genetic.py

@@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel):

    def predict(self, inputs):
        """predict"""
        # Adapt to the input shape requirements of the target network if inputs is only one image.
        if len(inputs.shape) == 3:
            inputs = np.expand_dims(inputs, axis=0)
        result = self._network(Tensor(inputs.astype(np.float32)))
        return result.asnumpy()



+ 3
- 0
examples/model_security/model_attacks/black_box/mnist_attack_pso.py

@@ -41,6 +41,9 @@ class ModelToBeAttacked(BlackModel):

    def predict(self, inputs):
        """predict"""
        # Adapt to the input shape requirements of the target network if inputs is only one image.
        if len(inputs.shape) == 3:
            inputs = np.expand_dims(inputs, axis=0)
        result = self._network(Tensor(inputs.astype(np.float32)))
        return result.asnumpy()



+ 150
- 0
examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_genetic.py

@@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PSO attack for Faster R-CNN"""
import os
import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import Tensor
from mindarmour import BlackModel
from mindarmour.adv_robustness.attacks.black.genetic_attack import GeneticAttack
from mindarmour.utils.logger import LogUtil
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
# pylint: disable=locally-disabled, unused-argument, redefined-outer-name
LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1)

class ModelToBeAttacked(BlackModel):
    """Model to be attacked."""

    def __init__(self, network):
        super(ModelToBeAttacked, self).__init__()
        self._network = network

    def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num):
        """predict"""
        # Adapt to the input shape requirements of the target network if inputs is only one image.
        if len(images.shape) == 3:
            inputs_num = 1
            images = np.expand_dims(images, axis=0)
        else:
            inputs_num = images.shape[0]
        box_and_confi = []
        pred_labels = []
        gt_number = np.sum(gt_num)
        for i in range(inputs_num):
            inputs_i = np.expand_dims(images[i], axis=0)
            output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas),
                                   Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num))
            all_bbox = output[0]
            all_labels = output[1]
            all_mask = output[2]
            all_bbox_squee = np.squeeze(all_bbox.asnumpy())
            all_labels_squee = np.squeeze(all_labels.asnumpy())
            all_mask_squee = np.squeeze(all_mask.asnumpy())
            all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
            all_labels_tmp_mask = all_labels_squee[all_mask_squee]
            # keep at most gt_number + 1 of the highest-confidence boxes
            if all_bboxes_tmp_mask.shape[0] > gt_number + 1:
                inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
                inds = inds[:gt_number + 1]
                all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
                all_labels_tmp_mask = all_labels_tmp_mask[inds]
            box_and_confi.append(all_bboxes_tmp_mask)
            pred_labels.append(all_labels_tmp_mask)
        return np.array(box_and_confi), np.array(pred_labels)

if __name__ == '__main__':
    prefix = 'FasterRcnn_eval.mindrecord'
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    pre_trained = '/ckpt_path'
    print("CHECKING MINDRECORD FILES ...")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if os.path.isdir(config.coco_root):
            print("Create Mindrecord. It may take some time.")
            data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
        else:
            print("coco_root does not exist.")
    print('Start generating adversarial samples.')
    # build network and dataset
    ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size,
                                   repeat_num=1, is_training=False)
    net = Faster_Rcnn_Resnet50(config)
    param_dict = load_checkpoint(pre_trained)
    load_param_into_net(net, param_dict)
    net = net.set_train(False)
    # build attacker
    model = ModelToBeAttacked(net)
    attack = GeneticAttack(model, model_type='detection', max_steps=50, reserve_ratio=0.3,
                           mutation_rate=0.05, per_bounds=0.5, step_size=0.25, temp=0.1)
    # generate adversarial samples
    sample_num = 5
    ori_images = []
    adv_imgs = []
    ori_meta = []
    ori_box = []
    ori_labels = []
    ori_gt_num = []
    idx = 0
    for data in ds.create_dict_iterator():
        if idx > sample_num:
            break
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']
        ori_images.append(img_data.asnumpy())
        ori_meta.append(img_metas.asnumpy())
        ori_box.append(gt_bboxes.asnumpy())
        ori_labels.append(gt_labels.asnumpy())
        ori_gt_num.append(gt_num.asnumpy())
        all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(),
                      gt_labels.asnumpy(), gt_num.asnumpy())
        pre_gt_boxes, pre_gt_label = model.predict(*all_inputs)
        success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label))
        adv_imgs.append(adv_img)
        idx += 1
    np.save('ori_images.npy', ori_images)
    np.save('ori_meta.npy', ori_meta)
    np.save('ori_box.npy', ori_box)
    np.save('ori_labels.npy', ori_labels)
    np.save('ori_gt_num.npy', ori_gt_num)
    np.save('adv_imgs.npy', adv_imgs)
    print('Generate adversarial samples complete.')
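
Once the script finishes, the saved arrays can be inspected offline. A small sketch, assuming the file names written above (allow_pickle is needed because the per-batch arrays may be ragged):

import numpy as np
ori = np.load('ori_images.npy', allow_pickle=True)
adv = np.load('adv_imgs.npy', allow_pickle=True)
# largest per-batch pixel distortion introduced by the attack
print([float(np.max(np.abs(a - o))) for a, o in zip(adv, ori)])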

+ 149
- 0
examples/model_security/model_attacks/cv/faster_rcnn/coco_attack_pso.py

@@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PSO attack for Faster R-CNN"""
import os
import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import Tensor
from mindarmour import BlackModel
from mindarmour.adv_robustness.attacks.black.pso_attack import PSOAttack
from mindarmour.utils.logger import LogUtil
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
# pylint: disable=locally-disabled, unused-argument, redefined-outer-name
LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1)

class ModelToBeAttacked(BlackModel):
    """Model to be attacked."""

    def __init__(self, network):
        super(ModelToBeAttacked, self).__init__()
        self._network = network

    def predict(self, images, img_metas, gt_boxes, gt_labels, gt_num):
        """predict"""
        # Adapt to the input shape requirements of the target network if inputs is only one image.
        if len(images.shape) == 3:
            inputs_num = 1
            images = np.expand_dims(images, axis=0)
        else:
            inputs_num = images.shape[0]
        box_and_confi = []
        pred_labels = []
        gt_number = np.sum(gt_num)
        for i in range(inputs_num):
            inputs_i = np.expand_dims(images[i], axis=0)
            output = self._network(Tensor(inputs_i.astype(np.float16)), Tensor(img_metas),
                                   Tensor(gt_boxes), Tensor(gt_labels), Tensor(gt_num))
            all_bbox = output[0]
            all_labels = output[1]
            all_mask = output[2]
            all_bbox_squee = np.squeeze(all_bbox.asnumpy())
            all_labels_squee = np.squeeze(all_labels.asnumpy())
            all_mask_squee = np.squeeze(all_mask.asnumpy())
            all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
            all_labels_tmp_mask = all_labels_squee[all_mask_squee]
            # keep at most gt_number + 1 of the highest-confidence boxes
            if all_bboxes_tmp_mask.shape[0] > gt_number + 1:
                inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
                inds = inds[:gt_number + 1]
                all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
                all_labels_tmp_mask = all_labels_tmp_mask[inds]
            box_and_confi.append(all_bboxes_tmp_mask)
            pred_labels.append(all_labels_tmp_mask)
        return np.array(box_and_confi), np.array(pred_labels)

if __name__ == '__main__':
    prefix = 'FasterRcnn_eval.mindrecord'
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    pre_trained = '/ckpt_path'
    print("CHECKING MINDRECORD FILES ...")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if os.path.isdir(config.coco_root):
            print("Create Mindrecord. It may take some time.")
            data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
        else:
            print("coco_root does not exist.")
    print('Start generating adversarial samples.')
    # build network and dataset
    ds = create_fasterrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size,
                                   repeat_num=1, is_training=False)
    net = Faster_Rcnn_Resnet50(config)
    param_dict = load_checkpoint(pre_trained)
    load_param_into_net(net, param_dict)
    net = net.set_train(False)
    # build attacker
    model = ModelToBeAttacked(net)
    attack = PSOAttack(model, c=0.2, t_max=50, pm=0.5, model_type='detection', reserve_ratio=0.3)
    # generate adversarial samples
    sample_num = 5
    ori_images = []
    adv_imgs = []
    ori_meta = []
    ori_box = []
    ori_labels = []
    ori_gt_num = []
    idx = 0
    for data in ds.create_dict_iterator():
        if idx > sample_num:
            break
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']
        ori_images.append(img_data.asnumpy())
        ori_meta.append(img_metas.asnumpy())
        ori_box.append(gt_bboxes.asnumpy())
        ori_labels.append(gt_labels.asnumpy())
        ori_gt_num.append(gt_num.asnumpy())
        all_inputs = (img_data.asnumpy(), img_metas.asnumpy(), gt_bboxes.asnumpy(),
                      gt_labels.asnumpy(), gt_num.asnumpy())
        pre_gt_boxes, pre_gt_label = model.predict(*all_inputs)
        success_flags, adv_img, query_times = attack.generate(all_inputs, (pre_gt_boxes, pre_gt_label))
        adv_imgs.append(adv_img)
        idx += 1
    np.save('ori_images.npy', ori_images)
    np.save('ori_meta.npy', ori_meta)
    np.save('ori_box.npy', ori_box)
    np.save('ori_labels.npy', ori_labels)
    np.save('ori_gt_num.npy', ori_gt_num)
    np.save('adv_imgs.npy', adv_imgs)
    print('Generate adversarial samples complete.')

+ 125
- 2
mindarmour/adv_robustness/attacks/attack.py

@@ -19,8 +19,10 @@ from abc import abstractmethod
import numpy as np

from mindarmour.utils._check_param import check_pair_numpy_param, \
check_int_positive
check_int_positive, check_equal_shape, check_numpy_param, check_model
from mindarmour.utils.util import calculate_iou
from mindarmour.utils.logger import LogUtil
from mindarmour.adv_robustness.attacks.black.black_model import BlackModel

LOGGER = LogUtil.get_instance()
TAG = 'Attack'
@@ -87,7 +89,6 @@ class Attack:
# Black-attack methods will return 3 values, just get the second.
res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)


adv_x = np.concatenate(res, axis=0)
return adv_x

@@ -109,3 +110,125 @@ class Attack:
'`Attack` and should be implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)

    @staticmethod
    def _reduction(x_ori, q_times, label, best_position, model, targeted_attack):
        """
        Decrease the differences between the original samples and adversarial samples.

        Args:
            x_ori (numpy.ndarray): Original samples.
            q_times (int): Query times.
            label (int): Target label or ground-truth label.
            best_position (numpy.ndarray): Adversarial examples.
            model (BlackModel): Target model.
            targeted_attack (bool): If True, it means this is a targeted attack. If False,
                it means this is an untargeted attack.

        Returns:
            - numpy.ndarray, adversarial examples after reduction.

            - int, total query times after reduction.

        Examples:
            >>> adv_x, times = Attack._reduction(x_ori, 20, 1, adv_x, model,
            >>> targeted_attack=True)
        """
        LOGGER.info(TAG, 'Reduction begins...')
        model = check_model('model', model, BlackModel)
        x_ori = check_numpy_param('x_ori', x_ori)
        best_position = check_numpy_param('best_position', best_position)
        x_ori, best_position = check_equal_shape('x_ori', x_ori,
                                                 'best_position', best_position)
        x_ori_fla = x_ori.flatten()
        best_position_fla = best_position.flatten()
        pixel_deep = np.max(x_ori) - np.min(x_ori)
        nums_pixel = len(x_ori_fla)
        for i in range(nums_pixel):
            diff = x_ori_fla[i] - best_position_fla[i]
            if abs(diff) > pixel_deep*0.1:
                # move the pixel halfway back towards the original image
                best_position_fla[i] += diff*0.5
                cur_label = np.argmax(
                    model.predict(best_position_fla.reshape(x_ori.shape)))
                q_times += 1
                # revert the step if the attack no longer succeeds
                if targeted_attack:
                    if cur_label != label:
                        best_position_fla[i] -= diff*0.5
                else:
                    if cur_label == label:
                        best_position_fla[i] -= diff*0.5
        return best_position_fla.reshape(x_ori.shape), q_times

    def _fast_reduction(self, x_ori, best_position, q_times, auxiliary_inputs, gt_boxes, gt_labels, model):
        """
        Decrease the differences between the original samples and adversarial samples in a fast way.

        Args:
            x_ori (numpy.ndarray): Original samples.
            best_position (numpy.ndarray): Adversarial examples.
            q_times (int): Query times.
            auxiliary_inputs (tuple): Auxiliary inputs matched with x_ori.
            gt_boxes (numpy.ndarray): Ground-truth boxes of x_ori.
            gt_labels (numpy.ndarray): Ground-truth labels of x_ori.
            model (BlackModel): Target model.

        Returns:
            - numpy.ndarray, adversarial examples after reduction.

            - int, total query times after reduction.
        """
        LOGGER.info(TAG, 'Reduction begins...')
        model = check_model('model', model, BlackModel)
        x_ori = check_numpy_param('x_ori', x_ori)
        best_position = check_numpy_param('best_position', best_position)
        x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position)
        x_shape = best_position.shape
        reduction_iters = 10000  # each step recovers about 0.01% of the perturbed pixels
        _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
        for _ in range(reduction_iters):
            diff = x_ori - best_position
            # move a random 0.01% of the pixels halfway back towards the original
            res = 0.5*diff*(np.random.random(x_shape) < 0.0001)
            best_position += res
            _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
            q_times += 1
            # revert the step if more objects become detectable again
            if correct_num > original_num:
                best_position -= res
        return best_position, q_times

    @staticmethod
    def _detection_scores(inputs, gt_boxes, gt_labels, model):
        """
        Evaluate the detection result of inputs, specifically for object detection models.

        Args:
            inputs (numpy.ndarray): Input samples.
            gt_boxes (numpy.ndarray): Ground-truth boxes of inputs.
            gt_labels (numpy.ndarray): Ground-truth labels of inputs.
            model (BlackModel): Target model.

        Returns:
            - numpy.ndarray, detection scores of inputs.

            - numpy.ndarray, the number of objects that are correctly detected.
        """
        model = check_model('model', model, BlackModel)
        box_and_confi, pred_labels = model.predict(*inputs)
        det_scores = []
        correct_labels_num = []
        gt_boxes_num = gt_boxes.shape[0]
        iou_thres = 0.5
        for boxes, labels in zip(box_and_confi, pred_labels):
            score = 0
            box_num = boxes.shape[0]
            correct_label_flag = np.zeros(gt_labels.shape)
            for i in range(box_num):
                pred_box = boxes[i]
                max_iou_confi = 0
                for j in range(gt_boxes_num):
                    iou = calculate_iou(pred_box[:4], gt_boxes[j][:4])
                    if labels[i] == gt_labels[j] and iou > iou_thres:
                        max_iou_confi = max(max_iou_confi, pred_box[-1] + iou)
                        correct_label_flag[j] = 1
                score += max_iou_confi
            det_scores.append(score)
            correct_labels_num.append(np.sum(correct_label_flag))
        return np.array(det_scores), np.array(correct_labels_num)
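
To make the scoring rule above concrete: a predicted box contributes to an image's detection score only when its label matches some ground-truth box with IoU above 0.5, in which case its confidence plus the IoU is added. A self-contained numpy sketch with toy boxes (the local iou helper mirrors calculate_iou added in this commit, inlined so the snippet runs standalone):

import numpy as np

def iou(box_a, box_b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    if x1 >= x2 or y1 >= y2:
        return 0.0
    inner = (x2 - x1)*(y2 - y1)
    s_a = (box_a[2] - box_a[0])*(box_a[3] - box_a[1])
    s_b = (box_b[2] - box_b[0])*(box_b[3] - box_b[1])
    return inner / (s_a + s_b - inner)

gt_boxes = np.array([[10., 10., 50., 50.]])
gt_labels = np.array([1])
pred_boxes = np.array([[12., 11., 49., 52., 0.9],   # overlaps the ground truth, label 1
                       [60., 60., 90., 90., 0.8]])  # no overlap
pred_labels = np.array([1, 1])

score, detected = 0.0, np.zeros(len(gt_labels))
for box, label in zip(pred_boxes, pred_labels):
    best = 0.0
    for j, gt in enumerate(gt_boxes):
        value = iou(box[:4], gt)
        if label == gt_labels[j] and value > 0.5:
            best = max(best, box[-1] + value)
            detected[j] = 1
    score += best
print(score, detected.sum())  # about 1.76 (confidence 0.9 + IoU), 1 object detected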

+ 147
- 76
mindarmour/adv_robustness/attacks/black/genetic_attack.py

@@ -20,38 +20,14 @@ from scipy.special import softmax
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_numpy_param, check_model, \
check_pair_numpy_param, check_param_type, check_value_positive, \
check_int_positive, check_param_multi_types
from ..attack import Attack
check_int_positive, check_detection_inputs, check_value_non_negative, check_param_multi_types
from mindarmour.adv_robustness.attacks.attack import Attack
from .black_model import BlackModel

LOGGER = LogUtil.get_instance()
TAG = 'GeneticAttack'


def _mutation(cur_pop, step_noise=0.01, prob=0.005):
"""
Generate mutation samples in genetic_attack.

Args:
cur_pop (numpy.ndarray): Samples before mutation.
step_noise (float): Noise range. Default: 0.01.
prob (float): Mutation probability. Default: 0.005.

Returns:
numpy.ndarray, samples after mutation operation in genetic_attack.

Examples:
>>> mul_pop = self._mutation_op([0.2, 0.3, 0.4], step_noise=0.03,
>>> prob=0.01)
"""
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5,
-step_noise, step_noise)
mutated_pop = perturb_noise*(
np.random.random(cur_pop.shape) < prob) + cur_pop
return mutated_pop


class GeneticAttack(Attack):
"""
The Genetic Attack represents the black-box attack based on the genetic algorithm,
@@ -65,6 +41,13 @@ class GeneticAttack(Attack):

Args:
model (BlackModel): Target model.
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now.
    Default: 'classification'.
targeted (bool): If True, turns on the targeted attack. If False,
    turns on untargeted attack. It should be noted that only untargeted attack
    is supported for model_type='detection'. Default: True.
reserve_ratio (float): The percentage of objects that can still be detected after the attack, specifically for
    model_type='detection'. Default: 0.3.
pop_size (int): The number of particles, which should be greater than
zero. Default: 6.
mutation_rate (float): The probability of mutations. Default: 0.005.
@@ -73,23 +56,33 @@ class GeneticAttack(Attack):
example. Default: 1000.
step_size (float): Attack step size. Default: 0.2.
temp (float): Sampling temperature for selection. Default: 0.3.
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
clip_max). Default: (0, 1.0)
The greater the temp, the greater the differences between individuals'
selecting probabilities.
bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form
of (clip_min, clip_max). Default: (0, 1.0).
adaptive (bool): If True, turns on dynamic scaling of mutation
parameters. If False, turns on static mutation parameters.
Default: False.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
c (float): Weight of perturbation loss. Default: 0.1.

Examples:
>>> attack = GeneticAttack(model)
"""
def __init__(self, model, pop_size=6,
mutation_rate=0.005, per_bounds=0.15, max_steps=1000,
step_size=0.20, temp=0.3, bounds=(0, 1.0), adaptive=False,
sparse=True):
def __init__(self, model, model_type='classification', targeted=True, reserve_ratio=0.3, sparse=True,
pop_size=6, mutation_rate=0.005, per_bounds=0.15, max_steps=1000, step_size=0.20, temp=0.3,
bounds=(0, 1.0), adaptive=False, c=0.1):
super(GeneticAttack, self).__init__()
self._model = check_model('model', model, BlackModel)
self._model_type = check_param_type('model_type', model_type, str)
self._targeted = check_param_type('targeted', targeted, bool)
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
if self._reserve_ratio > 1:
msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._sparse = check_param_type('sparse', sparse, bool)
self._per_bounds = check_value_positive('per_bounds', per_bounds)
self._pop_size = check_int_positive('pop_size', pop_size)
self._step_size = check_value_positive('step_size', step_size)
@@ -98,16 +91,41 @@ class GeneticAttack(Attack):
self._mutation_rate = check_value_positive('mutation_rate',
mutation_rate)
self._adaptive = check_param_type('adaptive', adaptive, bool)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
# initial global optimum fitness value
self._best_fit = -1
self._best_fit = -np.inf
# count times of no progress
self._plateau_times = 0
# count times of changing attack step
# count times of changing attack step_size
self._adap_times = 0
self._sparse = check_param_type('sparse', sparse, bool)
self._bounds = bounds
if self._bounds is not None:
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._c = check_value_positive('c', c)

    def _mutation(self, cur_pop, step_noise=0.01, prob=0.005):
        """
        Generate mutation samples in genetic_attack.

        Args:
            cur_pop (numpy.ndarray): Samples before mutation.
            step_noise (float): Noise range. Default: 0.01.
            prob (float): Mutation probability. Default: 0.005.

        Returns:
            numpy.ndarray, samples after mutation operation in genetic_attack.

        Examples:
            >>> mul_pop = self._mutation([0.2, 0.3, 0.4], step_noise=0.03,
            >>> prob=0.01)
        """
        cur_pop = check_numpy_param('cur_pop', cur_pop)
        perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5,
                                -step_noise, step_noise)*(self._bounds[1] - self._bounds[0])
        mutated_pop = perturb_noise*(
            np.random.random(cur_pop.shape) < prob) + cur_pop
        return mutated_pop
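
The rescaled mutation above can be exercised in isolation; a toy numpy equivalent, assuming bounds of (0, 1):

import numpy as np

bounds = (0.0, 1.0)
cur_pop = np.random.random((6, 8))  # population of 6 flattened samples
step_noise, prob = 0.01, 0.005
perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5,
                        -step_noise, step_noise)*(bounds[1] - bounds[0])
# each element mutates with probability `prob`
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < prob) + cur_pop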

def generate(self, inputs, labels):
"""
@@ -115,8 +133,10 @@ class GeneticAttack(Attack):
labels (or ground_truth labels).

Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Targeted labels.
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...)
or only one array if model_type='detection'.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels
should be (gt_boxes, gt_labels) if model_type='detection'.

Returns:
- numpy.ndarray, bool values for each attack result.
@@ -130,47 +150,96 @@ class GeneticAttack(Attack):
>>> [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
# if input is one-hot encoded, get sparse format value
if not self._sparse:
if labels.ndim != 2:
raise ValueError('labels must be 2 dims, '
'but got {} dims.'.format(labels.ndim))
labels = np.argmax(labels, axis=1)
if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
if labels.ndim != 2:
raise ValueError('labels must be 2 dims, '
'but got {} dims.'.format(labels.ndim))
labels = np.argmax(labels, axis=1)
images = inputs
elif self._model_type == 'detection':
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)

adv_list = []
success_list = []
query_times_list = []
for i in range(inputs.shape[0]):
for i in range(images.shape[0]):
is_success = False
target_label = labels[i]
iters = 0
x_ori = inputs[i]
x_ori = images[i]
if not self._bounds:
self._bounds = [np.min(x_ori), np.max(x_ori)]
pixel_deep = self._bounds[1] - self._bounds[0]

if self._model_type == 'classification':
label_i = labels[i]
elif self._model_type == 'detection':
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i]
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])

# generate particles
ori_copies = np.repeat(
x_ori[np.newaxis, :], self._pop_size, axis=0)
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
# initial perturbations
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size,
(0 - self._per_bounds),
self._per_bounds)
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep,
(0 - self._per_bounds)*pixel_deep,
self._per_bounds*pixel_deep)

query_times = 0
iters = 0
while iters < self._max_steps:
iters += 1
cur_pop = np.clip(
ori_copies + cur_pert, self._bounds[0], self._bounds[1])
pop_preds = self._model.predict(cur_pop)
query_times += cur_pop.shape[0]
all_preds = np.argmax(pop_preds, axis=1)
success_pop = np.equal(target_label, all_preds).astype(np.int32)
success = max(success_pop)
if success == 1:
is_success = True
adv = cur_pop[np.argmax(success_pop)]

if self._model_type == 'classification':
pop_preds = self._model.predict(cur_pop)
query_times += cur_pop.shape[0]
all_preds = np.argmax(pop_preds, axis=1)
if self._targeted:
success_pop = np.equal(label_i, all_preds).astype(np.int32)
else:
success_pop = np.not_equal(label_i, all_preds).astype(np.int32)
is_success = max(success_pop)
best_idx = np.argmax(success_pop)
target_preds = pop_preds[:, label_i]
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds
if self._targeted:
fit_vals = target_preds - others_preds_sum
else:
fit_vals = others_preds_sum - target_preds

elif self._model_type == 'detection':
confi_adv, correct_nums_adv = self._detection_scores(
(cur_pop,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
np.min(correct_nums_adv))
query_times += self._pop_size
fit_vals = abs(
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm(
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1)
if np.max(fit_vals) < 0:
self._c /= 2

if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio):
is_success = True
best_idx = np.argmin(correct_nums_adv)

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial sample '
'and start Reduction process.')
final_adv = cur_pop[best_idx]
if self._model_type == 'classification':
final_adv, query_times = self._reduction(x_ori, query_times, label_i, final_adv,
model=self._model, targeted_attack=self._targeted)
break
target_preds = pop_preds[:, target_label]
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds
fit_vals = target_preds - others_preds_sum
best_fit = max(target_preds - np.max(pop_preds))

best_fit = max(fit_vals)
if best_fit > self._best_fit:
self._best_fit = best_fit
self._plateau_times = 0
@@ -206,17 +275,19 @@ class GeneticAttack(Attack):
cross_probs = (np.random.random(parent1.shape) >
parent2_probs).astype(np.int32)
childs = parent1*cross_probs + parent2*(1 - cross_probs)
mutated_childs = _mutation(
mutated_childs = self._mutation(
childs, step_noise=self._per_bounds*step_noise,
prob=step_p)
cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :]))
if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial sample '
'and start Reduction process.')
adv_list.append(adv)
else:

if not is_success:
LOGGER.debug(TAG, 'fail to find adversarial sample.')
adv_list.append(elite + x_ori)
final_adv = elite + x_ori
if self._model_type == 'detection':
final_adv, query_times = self._fast_reduction(
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
adv_list.append(final_adv)

LOGGER.debug(TAG,
'iteration times is: %d and query times is: %d',
iters,
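
For the detection branch above, fitness rewards a large drop in the detection score while penalizing the L2 norm of the perturbation, and the weight c is halved whenever every particle's fitness goes negative. A toy numeric sketch of the trade-off (all values made up):

import numpy as np

pop_size, c = 6, 0.1
confi_ori = np.array([5.2])                           # detection score on the clean image
confi_adv = np.array([3.1, 4.8, 2.0, 5.0, 3.5, 2.9])  # scores of the 6 particles
x_ori = np.random.random((3, 32, 32)).astype(np.float32)
cur_pop = x_ori + 0.05*np.random.random((pop_size, 3, 32, 32)).astype(np.float32)

per_norm = np.linalg.norm((cur_pop - x_ori).reshape(pop_size, -1), axis=1)
fit_vals = np.abs(confi_ori - confi_adv) - c/pop_size*per_norm
print(int(np.argmax(fit_vals)))  # particle with the best score-drop vs. distortion balance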


+ 123
- 105
mindarmour/adv_robustness/attacks/black/pso_attack.py

@@ -19,7 +19,8 @@ import numpy as np
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \
check_numpy_param, check_value_positive, check_int_positive, \
check_param_type, check_equal_shape, check_param_multi_types
check_param_type, check_param_multi_types,\
check_value_non_negative, check_detection_inputs
from ..attack import Attack
from .black_model import BlackModel

@@ -53,18 +54,21 @@ class PSOAttack(Attack):
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
clip_max). Default: None.
targeted (bool): If True, turns on the targeted attack. If False,
turns on untargeted attack. Default: False.
reduction_iters (int): Cycle times in reduction process. Default: 3.
turns on untargeted attack. It should be noted that only untargeted attack
    is supported for model_type='detection'. Default: False.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now.
Default: 'classification'.
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for
model_type='detection'. Default: 0.3.

Examples:
>>> attack = PSOAttack(model)
"""

def __init__(self, model, step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0,
c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None,
targeted=False, reduction_iters=3, sparse=True):
def __init__(self, model, model_type='classification', targeted=False, reserve_ratio=0.3, sparse=True,
step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0, c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None):
super(PSOAttack, self).__init__()
self._model = check_model('model', model, BlackModel)
self._step_size = check_value_positive('step_size', step_size)
@@ -74,14 +78,24 @@ class PSOAttack(Attack):
self._c = check_value_positive('c', c)
self._pop_size = check_int_positive('pop_size', pop_size)
self._pm = check_value_positive('pm', pm)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._bounds = bounds
if self._bounds is not None:
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._targeted = check_param_type('targeted', targeted, bool)
self._t_max = check_int_positive('t_max', t_max)
self._reduce_iters = check_int_positive('reduction_iters',
reduction_iters)
self._sparse = check_param_type('sparse', sparse, bool)
self._model_type = check_param_type('model_type', model_type, str)
if self._model_type not in ('classification', 'detection'):
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
if self._reserve_ratio > 1:
msg = "reserve_ratio should be less than 1.0, but got {}.".format(self._reserve_ratio)
LOGGER.error(TAG, msg)
raise ValueError(msg)

def _fitness(self, confi_ori, confi_adv, x_ori, x_adv):
"""
@@ -109,63 +123,50 @@ class PSOAttack(Attack):
fit_value = abs(
confi_ori - confi_adv) - self._c / self._pop_size*np.linalg.norm(
(x_adv - x_ori).reshape(x_adv.shape[0], -1), axis=1)
if np.max(fit_value) < 0:
self._c /= 2
return fit_value

def _mutation_op(self, cur_pop):
def _confidence_cla(self, inputs, labels):
"""
Generate mutation samples.
Calculate the prediction confidence of corresponding label or max confidence of inputs.

Args:
inputs (numpy.ndarray): Input samples.
labels (Union[numpy.int, numpy.int16, numpy.int32, numpy.int64]): Target labels.

Returns:
float, the prediction confidences of inputs.
"""
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = np.random.random(cur_pop.shape) - 0.5
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape)
< self._pm) + cur_pop
mutated_pop = np.clip(mutated_pop, cur_pop*(1 - self._per_bounds),
cur_pop*(1 + self._per_bounds))
return mutated_pop
check_numpy_param('inputs', inputs)
check_param_multi_types('labels', labels, (np.int, np.int16, np.int32, np.int64))
confidences = self._model.predict(inputs)
if self._targeted:
confi_choose = confidences[:, labels]
else:
confi_choose = np.max(confidences, axis=1)
return confi_choose

def _reduction(self, x_ori, q_times, label, best_position):
def _mutation_op(self, cur_pop):
"""
Decrease the differences between the original samples and adversarial samples.
Generate mutation samples.

Args:
x_ori (numpy.ndarray): Original samples.
q_times (int): Query times.
label (int): Target label ot ground-truth label.
best_position (numpy.ndarray): Adversarial examples.
cur_pop (numpy.ndarray): Inputs before mutation operation.

Returns:
numpy.ndarray, adversarial examples after reduction.

Examples:
>>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1,
>>> [0.12, 0.15, 0.25])
numpy.ndarray, mutational inputs.
"""
x_ori = check_numpy_param('x_ori', x_ori)
best_position = check_numpy_param('best_position', best_position)
x_ori, best_position = check_equal_shape('x_ori', x_ori,
'best_position', best_position)
x_ori_fla = x_ori.flatten()
best_position_fla = best_position.flatten()
# LOGGER.info(TAG, 'Mutation happens...')
pixel_deep = self._bounds[1] - self._bounds[0]
nums_pixel = len(x_ori_fla)
for i in range(nums_pixel):
diff = x_ori_fla[i] - best_position_fla[i]
if abs(diff) > pixel_deep*0.1:
old_poi_fla = np.copy(best_position_fla)
best_position_fla[i] = np.clip(
best_position_fla[i] + diff*0.5,
self._bounds[0], self._bounds[1])
cur_label = np.argmax(
self._model.predict(np.expand_dims(
best_position_fla.reshape(x_ori.shape), axis=0))[0])
q_times += 1
if self._targeted:
if cur_label != label:
best_position_fla = old_poi_fla
else:
if cur_label == label:
best_position_fla = old_poi_fla
return best_position_fla.reshape(x_ori.shape), q_times
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep
mutated_pop = perturb_noise + cur_pop
if self._model_type == 'classification':
mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop),
cur_pop + self._per_bounds*np.abs(cur_pop)),
self._bounds[0], self._bounds[1])
return mutated_pop

def generate(self, inputs, labels):
"""
@@ -173,8 +174,10 @@ class PSOAttack(Attack):
labels (or ground_truth labels).

Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Targeted labels or ground_truth labels.
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (input1, input2, ...)
or only one array if model_type='detection'.
labels (Union[numpy.ndarray, tuple]): Targeted labels or ground_truth labels. The format of labels
should be (gt_boxes, gt_labels) if model_type='detection'.

Returns:
- numpy.ndarray, bool values for each attack result.
@@ -187,42 +190,51 @@ class PSOAttack(Attack):
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
# inputs check
if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
images = inputs
elif self._model_type == 'detection':
images, auxiliary_inputs, gt_boxes, gt_labels = check_detection_inputs(inputs, labels)

# generate one adversarial each time
if self._targeted:
target_labels = labels
adv_list = []
success_list = []
query_times_list = []
pixel_deep = self._bounds[1] - self._bounds[0]
for i in range(inputs.shape[0]):
for i in range(images.shape[0]):
is_success = False
q_times = 0
x_ori = inputs[i]
confidences = self._model.predict(np.expand_dims(x_ori, axis=0))[0]
x_ori = images[i]
if not self._bounds:
self._bounds = [np.min(x_ori), np.max(x_ori)]
pixel_deep = self._bounds[1] - self._bounds[0]

q_times += 1
true_label = labels[i]
if self._targeted:
t_label = target_labels[i]
confi_ori = confidences[t_label]
else:
confi_ori = max(confidences)
if self._model_type == 'classification':
label_i = labels[i]
confi_ori = self._confidence_cla(x_ori, label_i)
elif self._model_type == 'detection':
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[i], axis=0),)
gt_boxes_i, gt_labels_i = gt_boxes[i], gt_labels[i]
inputs_i = (images[i],) + auxiliary_input_i
confi_ori, gt_object_num = self._detection_scores(inputs_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of ground-truth objects is %s', gt_object_num[0])

# step1, initializing
# initial global optimum fitness value, cannot set to be 0
# initial global optimum fitness value, initialized to -inf
best_fitness = -np.inf
# initial global optimum position
best_position = x_ori
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)
*self._step_size,
(0 - self._per_bounds)*(x_copies + 0.1),
self._per_bounds*(x_copies + 0.1))
par = np.clip(x_copies + cur_noise,
x_copies*(1 - self._per_bounds),
x_copies*(1 + self._per_bounds))
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size,
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1),
self._per_bounds*(np.abs(x_copies) + 0.1))
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1])
# initial advs
par_ori = np.copy(par)
# initial optimum positions for particles
@@ -241,17 +253,17 @@ class PSOAttack(Attack):
v_particles = self._step_size*(
v_particles + self._c1*ran_1*(best_position - par)) \
+ self._c2*ran_2*(par_best_poi - par)
par = np.clip(par + v_particles,
(par_ori + 0.1*pixel_deep)*(
1 - self._per_bounds),
(par_ori + 0.1*pixel_deep)*(
1 + self._per_bounds))
if iters > 30 and is_mutation:
par = np.clip(np.clip(par + v_particles,
par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds,
par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds),
self._bounds[0], self._bounds[1])
if iters > 20 and is_mutation:
par = self._mutation_op(par)
if self._targeted:
confi_adv = self._model.predict(par)[:, t_label]
else:
confi_adv = np.max(self._model.predict(par), axis=1)
if self._model_type == 'classification':
confi_adv = self._confidence_cla(par, label_i)
elif self._model_type == 'detection':
confi_adv, _ = self._detection_scores(
(par,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
q_times += self._pop_size
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par)
for k in range(self._pop_size):
@@ -262,30 +274,36 @@ class PSOAttack(Attack):
best_fitness = fit_value[k]
best_position = par[k]
iters += 1
cur_pre = self._model.predict(np.expand_dims(best_position,
axis=0))[0]
is_mutation = False
if (best_fitness - last_best_fit) < last_best_fit*0.05:
is_mutation = True
cur_label = np.argmax(cur_pre)
q_times += 1
if self._targeted:
if cur_label == t_label:
if self._model_type == 'classification':
cur_pre = self._model.predict(best_position)
cur_label = np.argmax(cur_pre)
if (self._targeted and cur_label == label_i) or (not self._targeted and cur_label != label_i):
is_success = True
else:
if cur_label != true_label:
elif self._model_type == 'detection':
_, correct_nums_adv = self._detection_scores(
(best_position,) + auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
LOGGER.info(TAG, 'The number of correctly detected objects in adversarial image is %s',
correct_nums_adv[0])
if correct_nums_adv <= int(gt_object_num*self._reserve_ratio):
is_success = True

if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial '
'sample and start Reduction process')
# step3, reduction
if self._targeted:
best_position, q_times = self._reduction(
x_ori, q_times, t_label, best_position)
else:
best_position, q_times = self._reduction(
x_ori, q_times, true_label, best_position)
if self._model_type == 'classification':
best_position, q_times = self._reduction(x_ori, q_times, label_i, best_position, self._model,
targeted_attack=self._targeted)

break
if self._model_type == 'detection':
best_position, q_times = self._fast_reduction(x_ori, best_position, q_times,
auxiliary_input_i, gt_boxes_i, gt_labels_i, self._model)
if not is_success:
LOGGER.debug(TAG,
'fail to find adversarial sample, iteration '
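
The particle update loop above can be isolated into a few lines of numpy; a runnable sketch with toy shapes (constants match the defaults; best_position and the per-particle bests are chosen arbitrarily here):

import numpy as np

step_size, c1, c2, per_bounds, pixel_deep = 0.5, 2.0, 2.0, 0.6, 1.0
par = np.random.random((6, 8))       # particle positions
par_ori = np.copy(par)               # initial adversarial candidates
v_particles = np.zeros_like(par)
best_position = par[0]               # global best
par_best_poi = np.copy(par)          # per-particle bests

ran_1, ran_2 = np.random.random(par.shape), np.random.random(par.shape)
v_particles = step_size*(v_particles + c1*ran_1*(best_position - par)) \
              + c2*ran_2*(par_best_poi - par)
par = np.clip(np.clip(par + v_particles,
                      par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*per_bounds,
                      par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*per_bounds),
              0.0, 1.0)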


+ 58
- 0
mindarmour/utils/_check_param.py

@@ -267,3 +267,61 @@ def normalize_value(value, norm_level):
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
return norm_value.reshape(ori_shape)


def check_detection_inputs(inputs, labels):
    """
    Check the inputs for detection model attacks.

    Args:
        inputs (Union[numpy.ndarray, tuple]): Images and other auxiliary inputs for detection model.
        labels (tuple): Ground-truth boxes and ground-truth labels of inputs.

    Returns:
        - numpy.ndarray, images data.

        - tuple, auxiliary inputs, such as image shape.

        - numpy.ndarray, ground-truth boxes.

        - numpy.ndarray, ground-truth labels.
    """
    if isinstance(inputs, tuple):
        has_images = False
        auxiliary_inputs = tuple()
        for item in inputs:
            check_numpy_param('item', item)
            # the 4-D array is taken as the images; everything else is auxiliary
            if len(item.shape) == 4:
                images = item
                has_images = True
            else:
                auxiliary_inputs += (item,)
        if not has_images:
            msg = 'Inputs should contain images whose dimension is 4.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
    else:
        images = check_numpy_param('inputs', inputs)
        auxiliary_inputs = tuple()

    check_param_type('labels', labels, tuple)
    if len(labels) != 2:
        msg = 'Labels should contain two arrays (boxes-confidences array and ground-truth labels array), ' \
              'but got {} arrays.'.format(len(labels))
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    has_boxes = False
    has_labels = False
    for item in labels:
        check_numpy_param('item', item)
        if len(item.shape) == 3 and item.shape[2] == 5:
            gt_boxes = item
            has_boxes = True
        elif len(item.shape) == 2:
            gt_labels = item
            has_labels = True
    if (not has_boxes) or (not has_labels):
        msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), ' \
              'respectively. But got {} and {}.'.format(labels[0].shape, labels[1].shape)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return images, auxiliary_inputs, gt_boxes, gt_labels
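
A quick usage sketch of the accepted format, assuming MindArmour with this commit is importable: exactly one 4-D array in the inputs tuple is taken as the images, every other array becomes auxiliary, and labels must be the pair (gt_boxes, gt_labels):

import numpy as np
from mindarmour.utils._check_param import check_detection_inputs

images = np.zeros((2, 3, 64, 64), dtype=np.float32)  # 4-D, so treated as images
img_metas = np.zeros((2, 4), dtype=np.float32)       # treated as auxiliary input
gt_boxes = np.zeros((2, 16, 5), dtype=np.float32)    # (N, M, 5)
gt_labels = np.zeros((2, 16), dtype=np.int32)        # (N, M)

imgs, aux, boxes, labels = check_detection_inputs((images, img_metas),
                                                  (gt_boxes, gt_labels))
assert imgs is images and aux[0] is img_metas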

+ 35
- 0
mindarmour/utils/util.py

@@ -17,6 +17,8 @@ from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.utils._check_param import check_numpy_param

from .logger import LogUtil

LOGGER = LogUtil.get_instance()
@@ -164,3 +166,36 @@ class GradWrap(Cell):
"""
gout = self.grad(self.network)(inputs, weight)
return gout


def calculate_iou(box_i, box_j):
    """
    Calculate the intersection over union (IoU) of two boxes.

    Args:
        box_i (numpy.ndarray): Coordinates of the first box, in the format (x1, y1, x2, y2).
            (x1, y1) and (x2, y2) are the coordinates of the lower left corner and the
            upper right corner, respectively.
        box_j (numpy.ndarray): Coordinates of the second box, in the format (x1, y1, x2, y2).

    Returns:
        float, IoU of the two input boxes.
    """
    check_numpy_param('box_i', box_i)
    check_numpy_param('box_j', box_j)
    if box_i.shape[-1] != 4 or box_j.shape[-1] != 4:
        msg = 'The length of both coordinate arrays should be 4, but got {} and {}.'.format(box_i.shape, box_j.shape)
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    i_x1, i_y1, i_x2, i_y2 = box_i
    j_x1, j_y1, j_x2, j_y2 = box_j
    s_i = (i_x2 - i_x1)*(i_y2 - i_y1)
    s_j = (j_x2 - j_x1)*(j_y2 - j_y1)
    inner_left_line = max(i_x1, j_x1)
    inner_right_line = min(i_x2, j_x2)
    inner_top_line = min(i_y2, j_y2)
    inner_bottom_line = max(i_y1, j_y1)
    # no overlap if the inner rectangle is degenerate
    if inner_left_line >= inner_right_line or inner_top_line <= inner_bottom_line:
        return 0
    inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line)
    return inner_area / (s_i + s_j - inner_area)
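
A hand-checkable example, assuming this commit's util module is importable: the two boxes below overlap in the unit square from (1, 1) to (2, 2), so the IoU is 1 / (4 + 4 - 1) = 1/7:

import numpy as np
from mindarmour.utils.util import calculate_iou

box_i = np.array([0., 0., 2., 2.])
box_j = np.array([1., 1., 3., 3.])
print(calculate_iou(box_i, box_j))  # 0.142857...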

+ 3
- 0
tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py

@@ -36,6 +36,9 @@ class ModelToBeAttacked(BlackModel):

    def predict(self, inputs):
        """predict"""
        # Adapt to the input shape requirements of the target network if inputs is only one image.
        if len(inputs.shape) == 1:
            inputs = np.expand_dims(inputs, axis=0)
        result = self._network(Tensor(inputs.astype(np.float32)))
        return result.asnumpy()


