From d0d36a5085c7da2546906004bc895646ca08767d Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Wed, 23 Dec 2020 09:47:19 +0800
Subject: [PATCH] Fix an issue.

---
 .../model_attacks/black_box/mnist_attack_pointwise.py          |  2 +-
 examples/model_security/model_defenses/mnist_evaluation.py     |  1 +
 .../model_security/model_defenses/mnist_similarity_detector.py |  1 +
 mindarmour/adv_robustness/attacks/deep_fool.py                 | 10 ++++++----
 4 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py b/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py
index f835e58..3aa41de 100644
--- a/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py
+++ b/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py
@@ -110,7 +110,7 @@ def test_pointwise_attack_on_mnist():
     test_labels_onehot = np.eye(10)[true_labels]
     attack_evaluate = AttackEvaluate(np.concatenate(test_images),
                                      test_labels_onehot, adv_data,
-                                     adv_preds, targeted=is_target,
+                                     np.array(adv_preds), targeted=is_target,
                                      target_label=targeted_labels)
     LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
                 attack_evaluate.mis_classification_rate())
diff --git a/examples/model_security/model_defenses/mnist_evaluation.py b/examples/model_security/model_defenses/mnist_evaluation.py
index 893e135..b6dc7ab 100644
--- a/examples/model_security/model_defenses/mnist_evaluation.py
+++ b/examples/model_security/model_defenses/mnist_evaluation.py
@@ -39,6 +39,7 @@ from examples.common.dataset.data_processing import generate_mnist_dataset
 from examples.common.networks.lenet5.lenet5_net import LeNet5
 
 LOGGER = LogUtil.get_instance()
+LOGGER.set_level('INFO')
 TAG = 'Defense_Evaluate_Example'
 
 
diff --git a/examples/model_security/model_defenses/mnist_similarity_detector.py b/examples/model_security/model_defenses/mnist_similarity_detector.py
index a3b9993..253a20e 100644
--- a/examples/model_security/model_defenses/mnist_similarity_detector.py
+++ b/examples/model_security/model_defenses/mnist_similarity_detector.py
@@ -30,6 +30,7 @@ from examples.common.dataset.data_processing import generate_mnist_dataset
 from examples.common.networks.lenet5.lenet5_net import LeNet5
 
 LOGGER = LogUtil.get_instance()
+LOGGER.set_level('INFO')
 TAG = 'Similarity Detector test'
 
 
diff --git a/mindarmour/adv_robustness/attacks/deep_fool.py b/mindarmour/adv_robustness/attacks/deep_fool.py
index d98034e..db8cca3 100644
--- a/mindarmour/adv_robustness/attacks/deep_fool.py
+++ b/mindarmour/adv_robustness/attacks/deep_fool.py
@@ -152,10 +152,12 @@ class DeepFool(Attack):
         Generate adversarial examples based on input samples and original labels.
 
         Args:
-            inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (inputs1, input2, ...) \
-                or only one array if model_type='detection'
-            labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be \
-                (gt_boxes, gt_labels) if model_type='detection'.
+            inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be numpy.ndarray if
+                model_type='classification'. The format of inputs can be (input1, input2, ...) or only one array if
+                model_type='detection'.
+            labels (Union[numpy.ndarray, tuple]): Targeted labels or ground-truth labels. The format of labels should
+                be numpy.ndarray if model_type='classification'. The format of labels should be (gt_boxes, gt_labels)
+                if model_type='detection'.
 
         Returns:
             numpy.ndarray, adversarial examples.
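
Editor's note (illustration, not part of the patch): the first hunk exists because the pointwise-attack example accumulates predictions for each adversarial sample in a plain Python list, while AttackEvaluate performs numpy-style indexing and arithmetic on its adv_preds argument, so the raw list must be cast first. The sketch below isolates that cast, assuming the AttackEvaluate API as used in the example scripts (mindarmour.adv_robustness.evaluations); the shapes and random data are made-up stand-ins for the example's real variables.

import numpy as np
from mindarmour.adv_robustness.evaluations import AttackEvaluate

# Toy stand-ins for the example's variables: 5 MNIST-like images, 10 classes.
inputs = np.random.rand(5, 1, 32, 32).astype(np.float32)
labels_onehot = np.eye(10)[np.random.randint(0, 10, 5)]
adv_data = inputs + 0.01                  # pretend adversarial versions

adv_preds = []                            # the example builds this up per sample ...
for _adv in adv_data:
    adv_preds.append(np.random.rand(10))  # ... as a list of prediction rows

# Passing the raw list can break AttackEvaluate's array handling;
# the patch casts it to a numpy array, exactly as done here:
attack_evaluate = AttackEvaluate(inputs, labels_onehot, adv_data,
                                 np.array(adv_preds))
print(attack_evaluate.mis_classification_rate())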