diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index 6c8bc43..93f1dfc 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -182,11 +182,11 @@ class Attack: best_position = check_numpy_param('best_position', best_position) x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) x_shape = best_position.shape - reduction_iters = 10000 # recover 0.01% each step + reduction_iters = 1000 # recover 0.1% each step _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) for _ in range(reduction_iters): diff = x_ori - best_position - res = 0.5*diff*(np.random.random(x_shape) < 0.0001) + res = 0.5*diff*(np.random.random(x_shape) < 0.001) best_position += res _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) q_times += 1 diff --git a/mindarmour/adv_robustness/attacks/black/genetic_attack.py b/mindarmour/adv_robustness/attacks/black/genetic_attack.py index 9aa8a8c..f744d65 100644 --- a/mindarmour/adv_robustness/attacks/black/genetic_attack.py +++ b/mindarmour/adv_robustness/attacks/black/genetic_attack.py @@ -46,16 +46,16 @@ class GeneticAttack(Attack): targeted (bool): If True, turns on the targeted attack. If False, turns on untargeted attack. It should be noted that only untargeted attack is supproted for model_type='detection', Default: False. - reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for - model_type='detection'. Default: 0.3. + reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, + specifically for model_type='detection'. Default: 0.3. pop_size (int): The number of particles, which should be greater than zero. Default: 6. - mutation_rate (float): The probability of mutations. Default: 0.005. - per_bounds (float): Maximum L_inf distance. + mutation_rate (Union[int, float]): The probability of mutations. Default: 0.005. + per_bounds (Union[int, float]): Maximum L_inf distance. max_steps (int): The maximum round of iteration for each adversarial example. Default: 1000. - step_size (float): Attack step size. Default: 0.2. - temp (float): Sampling temperature for selection. Default: 0.3. + step_size (Union[int, float]): Attack step size. Default: 0.2. + temp (Union[int, float]): Sampling temperature for selection. Default: 0.3. The greater the temp, the greater the differences between individuals' selecting probabilities. bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form @@ -65,7 +65,7 @@ class GeneticAttack(Attack): Default: False. sparse (bool): If True, input labels are sparse-encoded. If False, input labels are one-hot-encoded. Default: True. - c (float): Weight of perturbation loss. Default: 0.1. + c (Union[int, float]): Weight of perturbation loss. Default: 0.1. Examples: >>> attack = GeneticAttack(model) @@ -76,6 +76,10 @@ class GeneticAttack(Attack): super(GeneticAttack, self).__init__() self._model = check_model('model', model, BlackModel) self._model_type = check_param_type('model_type', model_type, str) + if self._model_type not in ('classification', 'detection'): + msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) + LOGGER.error(TAG, msg) + raise ValueError(msg) self._targeted = check_param_type('targeted', targeted, bool) self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) if self._reserve_ratio > 1: @@ -153,10 +157,14 @@ class GeneticAttack(Attack): if self._model_type == 'classification': inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) - if not self._sparse: - if labels.ndim != 2: - raise ValueError('labels must be 2 dims, ' - 'but got {} dims.'.format(labels.ndim)) + if self._sparse: + label_squ = np.squeeze(labels) + if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: + msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ + "and got its shape as {}.".format(labels.shape) + LOGGER.error(TAG, msg) + raise ValueError + else: labels = np.argmax(labels, axis=1) images = inputs elif self._model_type == 'detection': diff --git a/mindarmour/adv_robustness/attacks/black/pso_attack.py b/mindarmour/adv_robustness/attacks/black/pso_attack.py index 46f8d0f..46e9e40 100644 --- a/mindarmour/adv_robustness/attacks/black/pso_attack.py +++ b/mindarmour/adv_robustness/attacks/black/pso_attack.py @@ -41,17 +41,17 @@ class PSOAttack(Attack): Args: model (BlackModel): Target model. - step_size (float): Attack step size. Default: 0.5. - per_bounds (float): Relative variation range of perturbations. Default: 0.6. - c1 (float): Weight coefficient. Default: 2. - c2 (float): Weight coefficient. Default: 2. - c (float): Weight of perturbation loss. Default: 2. + step_size (Union[int, float]): Attack step size. Default: 0.5. + per_bounds (Union[int, float]): Relative variation range of perturbations. Default: 0.6. + c1 (Union[int, float]): Weight coefficient. Default: 2. + c2 (Union[int, float]): Weight coefficient. Default: 2. + c (Union[int, float]): Weight of perturbation loss. Default: 2. pop_size (int): The number of particles, which should be greater than zero. Default: 6. t_max (int): The maximum round of iteration for each adversarial example, which should be greater than zero. Default: 1000. - pm (float): The probability of mutations. Default: 0.5. - bounds (tuple): Upper and lower bounds of data. In form of (clip_min, + pm (Union[int, float]): The probability of mutations. Default: 0.5. + bounds (Union[list, tuple, None]): Upper and lower bounds of data. In form of (clip_min, clip_max). Default: None. targeted (bool): If True, turns on the targeted attack. If False, turns on untargeted attack. It should be noted that only untargeted attack @@ -60,8 +60,8 @@ class PSOAttack(Attack): input labels are one-hot-encoded. Default: True. model_type (str): The type of targeted model. 'classification' and 'detection' are supported now. default: 'classification'. - reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for - model_type='detection'. Default: 0.3. + reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks, + specifically for model_type='detection'. Default: 0.3. Examples: >>> attack = PSOAttack(model) @@ -161,7 +161,7 @@ class PSOAttack(Attack): pixel_deep = self._bounds[1] - self._bounds[0] cur_pop = check_numpy_param('cur_pop', cur_pop) perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep - mutated_pop = perturb_noise + cur_pop + mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop if self._model_type == 'classification': mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), cur_pop + self._per_bounds*np.abs(cur_pop)), @@ -194,7 +194,14 @@ class PSOAttack(Attack): if self._model_type == 'classification': inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) - if not self._sparse: + if self._sparse: + label_squ = np.squeeze(labels) + if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: + msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ + "got its shape as {}.".format(labels.shape) + LOGGER.error(TAG, msg) + raise ValueError + else: labels = np.argmax(labels, axis=1) images = inputs elif self._model_type == 'detection': diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 92698eb..1adf18d 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -302,6 +302,8 @@ def check_detection_inputs(inputs, labels): raise ValueError(msg) else: check_numpy_param('inputs', inputs) + images = inputs + auxiliary_inputs = () check_param_type('labels', labels, tuple) if len(labels) != 2: diff --git a/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py index 27a7201..76f860d 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py @@ -24,8 +24,6 @@ from mindspore.nn import Cell from mindarmour import BlackModel from mindarmour.adv_robustness.attacks import GeneticAttack -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - # for user class ModelToBeAttacked(BlackModel): @@ -41,6 +39,28 @@ class ModelToBeAttacked(BlackModel): return result.asnumpy() +class DetectionModel(BlackModel): + """model to be attack""" + + def predict(self, inputs): + """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(inputs.shape) == 3: + inputs_num = 1 + else: + inputs_num = inputs.shape[0] + box_and_confi = [] + pred_labels = [] + gt_number = np.random.randint(1, 128) + + for _ in range(inputs_num): + boxes_i = np.random.random((gt_number, 5)) + labels_i = np.random.randint(0, 10, gt_number) + box_and_confi.append(boxes_i) + pred_labels.append(labels_i) + return np.array(box_and_confi), np.array(pred_labels) + + class SimpleNet(Cell): """ Construct the network of target model. @@ -76,6 +96,7 @@ def test_genetic_attack(): """ Genetic_Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") batch_size = 6 net = SimpleNet() @@ -98,6 +119,7 @@ def test_genetic_attack(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_supplement(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") batch_size = 6 net = SimpleNet() @@ -123,6 +145,7 @@ def test_supplement(): @pytest.mark.component_mindarmour def test_value_error(): """test that exception is raised for invalid labels""" + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") batch_size = 6 net = SimpleNet() @@ -140,3 +163,29 @@ def test_value_error(): # raise error with pytest.raises(ValueError): assert attack.generate(inputs, labels) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_genetic_attack_detection_cpu(): + """ + Genetic_Attack test + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + batch_size = 2 + inputs = np.random.random((batch_size, 3, 28, 28)) + model = DetectionModel() + attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05, + per_bounds=0.1, step_size=0.25, temp=0.1, + sparse=False, max_steps=50) + + # generate adversarial samples + adv_imgs = [] + for i in range(batch_size): + img_data = np.expand_dims(inputs[i], axis=0) + pre_gt_boxes, pre_gt_labels = model.predict(inputs) + _, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) + adv_imgs.append(adv_img) + assert np.any(inputs != np.array(adv_imgs)) diff --git a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py index 141785d..d3da309 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py @@ -43,6 +43,28 @@ class ModelToBeAttacked(BlackModel): return result.asnumpy() +class DetectionModel(BlackModel): + """model to be attack""" + + def predict(self, inputs): + """predict""" + # Adapt to the input shape requirements of the target network if inputs is only one image. + if len(inputs.shape) == 3: + inputs_num = 1 + else: + inputs_num = inputs.shape[0] + box_and_confi = [] + pred_labels = [] + gt_number = np.random.randint(1, 128) + + for _ in range(inputs_num): + boxes_i = np.random.random((gt_number, 5)) + labels_i = np.random.randint(0, 10, gt_number) + box_and_confi.append(boxes_i) + pred_labels.append(labels_i) + return np.array(box_and_confi), np.array(pred_labels) + + class SimpleNet(Cell): """ Construct the network of target model. @@ -167,3 +189,27 @@ def test_pso_attack_cpu(): attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False) _, adv_data, _ = attack.generate(inputs, labels) assert np.any(inputs != adv_data) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_pso_attack_detection_cpu(): + """ + PSO_Attack test + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + batch_size = 2 + inputs = np.random.random((batch_size, 3, 28, 28)) + model = DetectionModel() + attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5) + + # generate adversarial samples + adv_imgs = [] + for i in range(batch_size): + img_data = np.expand_dims(inputs[i], axis=0) + pre_gt_boxes, pre_gt_labels = model.predict(inputs) + _, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels)) + adv_imgs.append(adv_img) + assert np.any(inputs != np.array(adv_imgs))