Merge pull request !14 from zheng-huanhuan/2_mastertags/v1.2.1
@@ -0,0 +1,118 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import sys | |||||
import time | |||||
import numpy as np | |||||
import pytest | |||||
from scipy.special import softmax | |||||
from mindspore import Model | |||||
from mindspore import Tensor | |||||
from mindspore import context | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.evaluations.attack_evaluation import AttackEvaluate | |||||
from lenet5_net import LeNet5 | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
sys.path.append("..") | |||||
from data_processing import generate_mnist_dataset | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'M_DI2_FGSM_Test' | |||||
LOGGER.set_level('INFO') | |||||
@pytest.mark.level1 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_momentum_diverse_input_iterative_method(): | |||||
""" | |||||
M-DI2-FGSM Attack Test | |||||
""" | |||||
# upload trained network | |||||
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = LeNet5() | |||||
load_dict = load_checkpoint(ckpt_name) | |||||
load_param_into_net(net, load_dict) | |||||
# get test data | |||||
data_list = "./MNIST_unzip/test" | |||||
batch_size = 32 | |||||
ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||||
# prediction accuracy before attack | |||||
model = Model(net) | |||||
batch_num = 32 # the number of batches of attacking samples | |||||
test_images = [] | |||||
test_labels = [] | |||||
predict_labels = [] | |||||
i = 0 | |||||
for data in ds.create_tuple_iterator(): | |||||
i += 1 | |||||
images = data[0].astype(np.float32) | |||||
labels = data[1] | |||||
test_images.append(images) | |||||
test_labels.append(labels) | |||||
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(), | |||||
axis=1) | |||||
predict_labels.append(pred_labels) | |||||
if i >= batch_num: | |||||
break | |||||
predict_labels = np.concatenate(predict_labels) | |||||
true_labels = np.argmax(np.concatenate(test_labels), axis=1) | |||||
accuracy = np.mean(np.equal(predict_labels, true_labels)) | |||||
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy) | |||||
# attacking | |||||
attack = MomentumDiverseInputIterativeMethod(net) | |||||
start_time = time.clock() | |||||
adv_data = attack.batch_generate(np.concatenate(test_images), | |||||
np.concatenate(test_labels), batch_size=32) | |||||
stop_time = time.clock() | |||||
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() | |||||
# rescale predict confidences into (0, 1). | |||||
pred_logits_adv = softmax(pred_logits_adv, axis=1) | |||||
pred_labels_adv = np.argmax(pred_logits_adv, axis=1) | |||||
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels)) | |||||
LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv) | |||||
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1), | |||||
np.concatenate(test_labels), | |||||
adv_data.transpose(0, 2, 3, 1), | |||||
pred_logits_adv) | |||||
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s', | |||||
attack_evaluate.mis_classification_rate()) | |||||
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s', | |||||
attack_evaluate.avg_conf_adv_class()) | |||||
LOGGER.info(TAG, 'The average confidence of true class is : %s', | |||||
attack_evaluate.avg_conf_true_class()) | |||||
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original ' | |||||
'samples and adversarial samples are: %s', | |||||
attack_evaluate.avg_lp_distance()) | |||||
LOGGER.info(TAG, 'The average structural similarity between original ' | |||||
'samples and adversarial samples are: %s', | |||||
attack_evaluate.avg_ssim()) | |||||
LOGGER.info(TAG, 'The average costing time is %s', | |||||
(stop_time - start_time)/(batch_num*batch_size)) | |||||
if __name__ == '__main__': | |||||
test_momentum_diverse_input_iterative_method() |
@@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod', | |||||
'BasicIterativeMethod', | 'BasicIterativeMethod', | ||||
'MomentumIterativeMethod', | 'MomentumIterativeMethod', | ||||
'ProjectedGradientDescent', | 'ProjectedGradientDescent', | ||||
'DiverseInputIterativeMethod', | |||||
'MomentumDiverseInputIterativeMethod', | |||||
'DeepFool', | 'DeepFool', | ||||
'CarliniWagnerL2Attack', | 'CarliniWagnerL2Attack', | ||||
'JSMAAttack', | 'JSMAAttack', | ||||
@@ -46,7 +46,7 @@ class GradientMethod(Attack): | |||||
Default: None. | Default: None. | ||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: None. | In form of (clip_min, clip_max). Default: None. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.07, alpha=None, bounds=None, | def __init__(self, network, eps=0.07, alpha=None, bounds=None, | ||||
@@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod): | |||||
Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
Examples: | Examples: | ||||
>>> attack = FastGradientMethod(network) | >>> attack = FastGradientMethod(network) | ||||
@@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod): | |||||
Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
Raises: | Raises: | ||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
@@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod): | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
Examples: | Examples: | ||||
>>> attack = FastGradientSignMethod(network) | >>> attack = FastGradientSignMethod(network) | ||||
@@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
is_targeted (bool): True: targeted attack. False: untargeted attack. | is_targeted (bool): True: targeted attack. False: untargeted attack. | ||||
Default: False. | Default: False. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
Raises: | Raises: | ||||
ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
@@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||||
Default: None. | Default: None. | ||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
Examples: | Examples: | ||||
>>> attack = LeastLikelyClassMethod(network) | >>> attack = LeastLikelyClassMethod(network) | ||||
@@ -15,6 +15,7 @@ | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
import numpy as np | import numpy as np | ||||
from PIL import Image, ImageOps | |||||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | from mindspore.nn import SoftmaxCrossEntropyWithLogits | ||||
from mindspore import Tensor | from mindspore import Tensor | ||||
@@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack): | |||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | ||||
loss_fn=None): | loss_fn=None): | ||||
@@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
attack. Default: False. | attack. Default: False. | ||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
attack (class): The single step gradient method of each iteration. In | attack (class): The single step gradient method of each iteration. In | ||||
this class, FGSM is used. | this class, FGSM is used. | ||||
Examples: | Examples: | ||||
>>> attack = BasicIterativeMethod(network) | >>> attack = BasicIterativeMethod(network) | ||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
is_targeted=False, nb_iter=5, loss_fn=None): | is_targeted=False, nb_iter=5, loss_fn=None): | ||||
super(BasicIterativeMethod, self).__init__(network, | super(BasicIterativeMethod, self).__init__(network, | ||||
@@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
clip_diff = clip_max - clip_min | clip_diff = clip_max - clip_min | ||||
for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
adv_x = self._attack.generate(inputs, labels) | |||||
if 'self.prob' in globals(): | |||||
d_inputs = _transform_inputs(inputs, self.prob) | |||||
else: | |||||
d_inputs = inputs | |||||
adv_x = self._attack.generate(d_inputs, labels) | |||||
perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | ||||
self._eps*clip_diff) | self._eps*clip_diff) | ||||
adv_x = arr_x + perturs | adv_x = arr_x + perturs | ||||
inputs = adv_x | inputs = adv_x | ||||
else: | else: | ||||
for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
adv_x = self._attack.generate(inputs, labels) | |||||
if 'self.prob' in globals(): | |||||
d_inputs = _transform_inputs(inputs, self.prob) | |||||
else: | |||||
d_inputs = inputs | |||||
adv_x = self._attack.generate(d_inputs, labels) | |||||
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
inputs = adv_x | inputs = adv_x | ||||
return adv_x | return adv_x | ||||
@@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
decay_factor (float): Decay factor in iterations. Default: 1.0. | decay_factor (float): Decay factor in iterations. Default: 1.0. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
@@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
clip_diff = clip_max - clip_min | clip_diff = clip_max - clip_min | ||||
for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
gradient = self._gradient(inputs, labels) | |||||
if 'self.prob' in globals(): | |||||
d_inputs = _transform_inputs(inputs, self.prob) | |||||
else: | |||||
d_inputs = inputs | |||||
gradient = self._gradient(d_inputs, labels) | |||||
momentum = self._decay_factor*momentum + gradient | momentum = self._decay_factor*momentum + gradient | ||||
adv_x = inputs + self._eps_iter*np.sign(momentum) | |||||
adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||||
perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | ||||
self._eps*clip_diff) | self._eps*clip_diff) | ||||
adv_x = arr_x + perturs | adv_x = arr_x + perturs | ||||
@@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
inputs = adv_x | inputs = adv_x | ||||
else: | else: | ||||
for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
gradient = self._gradient(inputs, labels) | |||||
if 'self.prob' in globals(): | |||||
d_inputs = _transform_inputs(inputs, self.prob) | |||||
else: | |||||
d_inputs = inputs | |||||
gradient = self._gradient(d_inputs, labels) | |||||
momentum = self._decay_factor*momentum + gradient | momentum = self._decay_factor*momentum + gradient | ||||
adv_x = inputs + self._eps_iter*np.sign(momentum) | |||||
adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||||
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
inputs = adv_x | inputs = adv_x | ||||
return adv_x | return adv_x | ||||
def _gradient(self, inputs, labels): | def _gradient(self, inputs, labels): | ||||
@@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
loss_fn (Loss): Loss function for optimization. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | """ | ||||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
@@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
inputs = adv_x | inputs = adv_x | ||||
return adv_x | return adv_x | ||||
class DiverseInputIterativeMethod(BasicIterativeMethod): | |||||
""" | |||||
The Diverse Input Iterative Method attack. | |||||
References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||||
Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||||
Args: | |||||
network (Cell): Target model. | |||||
eps (float): Proportion of adversarial perturbation generated by the | |||||
attack to data range. Default: 0.3. | |||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||||
attack. Default: False. | |||||
prob (float): Transformation probability. Default: 0.5. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | |||||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||||
is_targeted=False, prob=0.5, loss_fn=None): | |||||
# reference to paper hyper parameters setting. | |||||
eps_iter = 16*2/255 | |||||
nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||||
super(DiverseInputIterativeMethod, self).__init__(network, | |||||
eps=eps, | |||||
eps_iter=eps_iter, | |||||
bounds=bounds, | |||||
is_targeted=is_targeted, | |||||
nb_iter=nb_iter, | |||||
loss_fn=loss_fn) | |||||
# FGSM default alpha is None equal alpha=1 | |||||
self.prob = check_param_type('prob', prob, float) | |||||
class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||||
""" | |||||
The Momentum Diverse Input Iterative Method attack. | |||||
References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||||
Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||||
Args: | |||||
network (Cell): Target model. | |||||
eps (float): Proportion of adversarial perturbation generated by the | |||||
attack to data range. Default: 0.3. | |||||
bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||||
In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||||
is_targeted (bool): If True, targeted attack. If False, untargeted | |||||
attack. Default: False. | |||||
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||||
np.inf, 1 or 2. Default: 'l1'. | |||||
prob (float): Transformation probability. Default: 0.5. | |||||
loss_fn (Loss): Loss function for optimization. Default: None. | |||||
""" | |||||
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||||
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | |||||
eps_iter = 16*2 / 255 | |||||
nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||||
super(MomentumDiverseInputIterativeMethod, self).__init__(network=network, | |||||
eps=eps, | |||||
eps_iter=eps_iter, | |||||
bounds=bounds, | |||||
nb_iter=nb_iter, | |||||
is_targeted=is_targeted, | |||||
norm_level=norm_level, | |||||
loss_fn=loss_fn) | |||||
self.prob = check_param_type('prob', prob, float) | |||||
def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): | |||||
""" | |||||
Inputs data augmentation. | |||||
Args: | |||||
inputs (Union[np.int8, np.float]): Inputs. | |||||
prob (float): The probability of augmentation. | |||||
low (int): Lower bound of resize image width. Default: 29. | |||||
high (int): Upper bound of resize image height. Default: 33. | |||||
full_aug (bool): type of augmentation method, use interpolation and padding | |||||
as default. Default: False. | |||||
Returns: | |||||
numpy.ndarray, the augmentation data. | |||||
""" | |||||
raw_shape = inputs[0].shape | |||||
tran_mask = np.random.uniform(0, 1, size=inputs.shape[0]) < prob | |||||
tran_inputs = inputs[tran_mask] | |||||
raw_inputs = inputs[tran_mask == 0] | |||||
tran_outputs = [] | |||||
for sample in tran_inputs: | |||||
width = np.random.choice(np.arange(low, high)) | |||||
# resize | |||||
sample = (sample*255).astype(np.uint8) | |||||
d_image = Image.fromarray(sample, mode='L').resize((width, width), Image.NEAREST) | |||||
# pad | |||||
left_pad = (raw_shape[0] - width) // 2 | |||||
right_pad = raw_shape[0] - width - left_pad | |||||
top_pad = (raw_shape[1] - width) // 2 | |||||
bottom_pad = raw_shape[1] - width - top_pad | |||||
p_sample = ImageOps.expand(d_image, | |||||
border=(left_pad, top_pad, right_pad, bottom_pad)) | |||||
tran_outputs.append(np.array(p_sample).astype(np.float) / 255) | |||||
if full_aug: | |||||
# gaussian noise | |||||
tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs | |||||
tran_outputs.extend(raw_inputs) | |||||
if not np.any(tran_outputs-raw_inputs): | |||||
LOGGER.error(TAG, 'the transform function does not take effect.') | |||||
return tran_outputs |
@@ -242,7 +242,7 @@ def normalize_value(value, norm_level): | |||||
Raises: | Raises: | ||||
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | ||||
'inf] | |||||
'inf', 'l1', 'l2'] | |||||
""" | """ | ||||
norm_level = check_norm_level(norm_level) | norm_level = check_norm_level(norm_level) | ||||
ori_shape = value.shape | ori_shape = value.shape | ||||
@@ -1,6 +1,7 @@ | |||||
numpy >= 1.17.0 | numpy >= 1.17.0 | ||||
scipy >= 1.3.3 | scipy >= 1.3.3 | ||||
matplotlib >= 3.1.3 | matplotlib >= 3.1.3 | ||||
Pillow >= 2.0.0 | |||||
pytest >= 4.3.1 | pytest >= 4.3.1 | ||||
wheel >= 0.32.0 | wheel >= 0.32.0 | ||||
setuptools >= 40.8.0 | setuptools >= 40.8.0 |
@@ -95,7 +95,8 @@ setup( | |||||
install_requires=[ | install_requires=[ | ||||
'scipy >= 1.3.3', | 'scipy >= 1.3.3', | ||||
'numpy >= 1.17.0', | 'numpy >= 1.17.0', | ||||
'matplotlib >= 3.1.3' | |||||
'matplotlib >= 3.1.3', | |||||
'Pillow >= 2.0.0' | |||||
], | ], | ||||
) | ) | ||||
print(find_packages()) | print(find_packages()) |
@@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod | |||||
from mindarmour.attacks import MomentumIterativeMethod | from mindarmour.attacks import MomentumIterativeMethod | ||||
from mindarmour.attacks import ProjectedGradientDescent | from mindarmour.attacks import ProjectedGradientDescent | ||||
from mindarmour.attacks import IterativeGradientMethod | from mindarmour.attacks import IterativeGradientMethod | ||||
from mindarmour.attacks import DiverseInputIterativeMethod | |||||
from mindarmour.attacks import MomentumDiverseInputIterativeMethod | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
@@ -91,7 +93,7 @@ def test_momentum_iterative_method(): | |||||
for i in range(5): | for i in range(5): | ||||
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | ||||
ms_adv_x = attack.generate(input_np, label) | ms_adv_x = attack.generate(input_np, label) | ||||
assert np.any(ms_adv_x != input_np), 'Basic iterative method: generate' \ | |||||
assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \ | |||||
' value must not be equal to' \ | ' value must not be equal to' \ | ||||
' original value.' | ' original value.' | ||||
@@ -124,6 +126,48 @@ def test_projected_gradient_descent_method(): | |||||
@pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
@pytest.mark.env_card | @pytest.mark.env_card | ||||
@pytest.mark.component_mindarmour | @pytest.mark.component_mindarmour | ||||
def test_diverse_input_iterative_method(): | |||||
""" | |||||
Diverse input iterative method unit test. | |||||
""" | |||||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||||
label = np.asarray([2], np.int32) | |||||
label = np.eye(3)[label].astype(np.float32) | |||||
for i in range(5): | |||||
attack = DiverseInputIterativeMethod(Net()) | |||||
ms_adv_x = attack.generate(input_np, label) | |||||
assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ | |||||
' value must not be equal to' \ | |||||
' original value.' | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_momentum_diverse_input_iterative_method(): | |||||
""" | |||||
Momentum diverse input iterative method unit test. | |||||
""" | |||||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||||
label = np.asarray([2], np.int32) | |||||
label = np.eye(3)[label].astype(np.float32) | |||||
for i in range(5): | |||||
attack = MomentumDiverseInputIterativeMethod(Net()) | |||||
ms_adv_x = attack.generate(input_np, label) | |||||
assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ | |||||
'generate value must not be equal to' \ | |||||
' original value.' | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.env_card | |||||
@pytest.mark.component_mindarmour | |||||
def test_error(): | def test_error(): | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
# check_param_multi_types | # check_param_multi_types | ||||