@@ -182,11 +182,11 @@ class Attack: | |||
best_position = check_numpy_param('best_position', best_position) | |||
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | |||
x_shape = best_position.shape | |||
reduction_iters = 10000 # recover 0.01% each step | |||
reduction_iters = 1000 # recover 0.1% each step | |||
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
for _ in range(reduction_iters): | |||
diff = x_ori - best_position | |||
res = 0.5*diff*(np.random.random(x_shape) < 0.0001) | |||
res = 0.5*diff*(np.random.random(x_shape) < 0.001) | |||
best_position += res | |||
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
q_times += 1 | |||
@@ -46,16 +46,16 @@ class GeneticAttack(Attack): | |||
targeted (bool): If True, turns on the targeted attack. If False, | |||
turns on untargeted attack. It should be noted that only untargeted attack | |||
is supproted for model_type='detection', Default: False. | |||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||
model_type='detection'. Default: 0.3. | |||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | |||
specifically for model_type='detection'. Default: 0.3. | |||
pop_size (int): The number of particles, which should be greater than | |||
zero. Default: 6. | |||
mutation_rate (float): The probability of mutations. Default: 0.005. | |||
per_bounds (float): Maximum L_inf distance. | |||
mutation_rate (Union[int, float]): The probability of mutations. Default: 0.005. | |||
per_bounds (Union[int, float]): Maximum L_inf distance. | |||
max_steps (int): The maximum round of iteration for each adversarial | |||
example. Default: 1000. | |||
step_size (float): Attack step size. Default: 0.2. | |||
temp (float): Sampling temperature for selection. Default: 0.3. | |||
step_size (Union[int, float]): Attack step size. Default: 0.2. | |||
temp (Union[int, float]): Sampling temperature for selection. Default: 0.3. | |||
The greater the temp, the greater the differences between individuals' | |||
selecting probabilities. | |||
bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form | |||
@@ -65,7 +65,7 @@ class GeneticAttack(Attack): | |||
Default: False. | |||
sparse (bool): If True, input labels are sparse-encoded. If False, | |||
input labels are one-hot-encoded. Default: True. | |||
c (float): Weight of perturbation loss. Default: 0.1. | |||
c (Union[int, float]): Weight of perturbation loss. Default: 0.1. | |||
Examples: | |||
>>> attack = GeneticAttack(model) | |||
@@ -76,6 +76,10 @@ class GeneticAttack(Attack): | |||
super(GeneticAttack, self).__init__() | |||
self._model = check_model('model', model, BlackModel) | |||
self._model_type = check_param_type('model_type', model_type, str) | |||
if self._model_type not in ('classification', 'detection'): | |||
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
self._targeted = check_param_type('targeted', targeted, bool) | |||
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) | |||
if self._reserve_ratio > 1: | |||
@@ -153,10 +157,14 @@ class GeneticAttack(Attack): | |||
if self._model_type == 'classification': | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if not self._sparse: | |||
if labels.ndim != 2: | |||
raise ValueError('labels must be 2 dims, ' | |||
'but got {} dims.'.format(labels.ndim)) | |||
if self._sparse: | |||
label_squ = np.squeeze(labels) | |||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ | |||
"and got its shape as {}.".format(labels.shape) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError | |||
else: | |||
labels = np.argmax(labels, axis=1) | |||
images = inputs | |||
elif self._model_type == 'detection': | |||
@@ -41,17 +41,17 @@ class PSOAttack(Attack): | |||
Args: | |||
model (BlackModel): Target model. | |||
step_size (float): Attack step size. Default: 0.5. | |||
per_bounds (float): Relative variation range of perturbations. Default: 0.6. | |||
c1 (float): Weight coefficient. Default: 2. | |||
c2 (float): Weight coefficient. Default: 2. | |||
c (float): Weight of perturbation loss. Default: 2. | |||
step_size (Union[int, float]): Attack step size. Default: 0.5. | |||
per_bounds (Union[int, float]): Relative variation range of perturbations. Default: 0.6. | |||
c1 (Union[int, float]): Weight coefficient. Default: 2. | |||
c2 (Union[int, float]): Weight coefficient. Default: 2. | |||
c (Union[int, float]): Weight of perturbation loss. Default: 2. | |||
pop_size (int): The number of particles, which should be greater | |||
than zero. Default: 6. | |||
t_max (int): The maximum round of iteration for each adversarial example, | |||
which should be greater than zero. Default: 1000. | |||
pm (float): The probability of mutations. Default: 0.5. | |||
bounds (tuple): Upper and lower bounds of data. In form of (clip_min, | |||
pm (Union[int, float]): The probability of mutations. Default: 0.5. | |||
bounds (Union[list, tuple, None]): Upper and lower bounds of data. In form of (clip_min, | |||
clip_max). Default: None. | |||
targeted (bool): If True, turns on the targeted attack. If False, | |||
turns on untargeted attack. It should be noted that only untargeted attack | |||
@@ -60,8 +60,8 @@ class PSOAttack(Attack): | |||
input labels are one-hot-encoded. Default: True. | |||
model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. | |||
default: 'classification'. | |||
reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for | |||
model_type='detection'. Default: 0.3. | |||
reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, | |||
specifically for model_type='detection'. Default: 0.3. | |||
Examples: | |||
>>> attack = PSOAttack(model) | |||
@@ -161,7 +161,7 @@ class PSOAttack(Attack): | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep | |||
mutated_pop = perturb_noise + cur_pop | |||
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop | |||
if self._model_type == 'classification': | |||
mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), | |||
cur_pop + self._per_bounds*np.abs(cur_pop)), | |||
@@ -194,7 +194,14 @@ class PSOAttack(Attack): | |||
if self._model_type == 'classification': | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if not self._sparse: | |||
if self._sparse: | |||
label_squ = np.squeeze(labels) | |||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ | |||
"got its shape as {}.".format(labels.shape) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError | |||
else: | |||
labels = np.argmax(labels, axis=1) | |||
images = inputs | |||
elif self._model_type == 'detection': | |||
@@ -302,6 +302,8 @@ def check_detection_inputs(inputs, labels): | |||
raise ValueError(msg) | |||
else: | |||
check_numpy_param('inputs', inputs) | |||
images = inputs | |||
auxiliary_inputs = () | |||
check_param_type('labels', labels, tuple) | |||
if len(labels) != 2: | |||
@@ -24,8 +24,6 @@ from mindspore.nn import Cell | |||
from mindarmour import BlackModel | |||
from mindarmour.adv_robustness.attacks import GeneticAttack | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for user | |||
class ModelToBeAttacked(BlackModel): | |||
@@ -41,6 +39,28 @@ class ModelToBeAttacked(BlackModel): | |||
return result.asnumpy() | |||
class DetectionModel(BlackModel): | |||
"""model to be attack""" | |||
def predict(self, inputs): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(inputs.shape) == 3: | |||
inputs_num = 1 | |||
else: | |||
inputs_num = inputs.shape[0] | |||
box_and_confi = [] | |||
pred_labels = [] | |||
gt_number = np.random.randint(1, 128) | |||
for _ in range(inputs_num): | |||
boxes_i = np.random.random((gt_number, 5)) | |||
labels_i = np.random.randint(0, 10, gt_number) | |||
box_and_confi.append(boxes_i) | |||
pred_labels.append(labels_i) | |||
return np.array(box_and_confi), np.array(pred_labels) | |||
class SimpleNet(Cell): | |||
""" | |||
Construct the network of target model. | |||
@@ -76,6 +96,7 @@ def test_genetic_attack(): | |||
""" | |||
Genetic_Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
batch_size = 6 | |||
net = SimpleNet() | |||
@@ -98,6 +119,7 @@ def test_genetic_attack(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_supplement(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
batch_size = 6 | |||
net = SimpleNet() | |||
@@ -123,6 +145,7 @@ def test_supplement(): | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
"""test that exception is raised for invalid labels""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
batch_size = 6 | |||
net = SimpleNet() | |||
@@ -140,3 +163,29 @@ def test_value_error(): | |||
# raise error | |||
with pytest.raises(ValueError): | |||
assert attack.generate(inputs, labels) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_genetic_attack_detection_cpu(): | |||
""" | |||
Genetic_Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
batch_size = 2 | |||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||
model = DetectionModel() | |||
attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05, | |||
per_bounds=0.1, step_size=0.25, temp=0.1, | |||
sparse=False, max_steps=50) | |||
# generate adversarial samples | |||
adv_imgs = [] | |||
for i in range(batch_size): | |||
img_data = np.expand_dims(inputs[i], axis=0) | |||
pre_gt_boxes, pre_gt_labels = model.predict(inputs) | |||
_, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) | |||
adv_imgs.append(adv_img) | |||
assert np.any(inputs != np.array(adv_imgs)) |
@@ -43,6 +43,28 @@ class ModelToBeAttacked(BlackModel): | |||
return result.asnumpy() | |||
class DetectionModel(BlackModel): | |||
"""model to be attack""" | |||
def predict(self, inputs): | |||
"""predict""" | |||
# Adapt to the input shape requirements of the target network if inputs is only one image. | |||
if len(inputs.shape) == 3: | |||
inputs_num = 1 | |||
else: | |||
inputs_num = inputs.shape[0] | |||
box_and_confi = [] | |||
pred_labels = [] | |||
gt_number = np.random.randint(1, 128) | |||
for _ in range(inputs_num): | |||
boxes_i = np.random.random((gt_number, 5)) | |||
labels_i = np.random.randint(0, 10, gt_number) | |||
box_and_confi.append(boxes_i) | |||
pred_labels.append(labels_i) | |||
return np.array(box_and_confi), np.array(pred_labels) | |||
class SimpleNet(Cell): | |||
""" | |||
Construct the network of target model. | |||
@@ -167,3 +189,27 @@ def test_pso_attack_cpu(): | |||
attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) | |||
_, adv_data, _ = attack.generate(inputs, labels) | |||
assert np.any(inputs != adv_data) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_pso_attack_detection_cpu(): | |||
""" | |||
PSO_Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
batch_size = 2 | |||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||
model = DetectionModel() | |||
attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5) | |||
# generate adversarial samples | |||
adv_imgs = [] | |||
for i in range(batch_size): | |||
img_data = np.expand_dims(inputs[i], axis=0) | |||
pre_gt_boxes, pre_gt_labels = model.predict(inputs) | |||
_, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) | |||
adv_imgs.append(adv_img) | |||
assert np.any(inputs != np.array(adv_imgs)) |