@@ -1,3 +1,42 @@ | |||
# MindArmour 1.2.1 | |||
## MindArmour 1.2.1 Release Notes | |||
### Major Features and Improvements | |||
### API Change | |||
#### Backwards Incompatible Change | |||
##### C++ API | |||
[Modify] ... | |||
[Add] ... | |||
[Delete] ... | |||
##### Java API | |||
[Add] ... | |||
#### Deprecations | |||
##### C++ API | |||
##### Java API | |||
### Bug fixes | |||
* [BUGFIX] Fix a bug of PGD method | |||
* [BUGFIX] Fix a bug of JSMA method | |||
### Contributors | |||
Thanks goes to these wonderful people: | |||
Liu Liu, Zhidan Liu, Luobin Liu and Xiulang Jin. | |||
Contributions of any kind are welcome! | |||
# MindArmour 1.2.0 | |||
## MindArmour 1.2.0 Release Notes | |||
@@ -75,7 +75,6 @@ class GradientMethod(Attack): | |||
else: | |||
with_loss_cell = WithLossCell(self._network, loss_fn) | |||
self._grad_all = GradWrapWithLoss(with_loss_cell) | |||
self._grad_all.set_train() | |||
def generate(self, inputs, labels): | |||
""" | |||
@@ -14,6 +14,7 @@ | |||
""" Iterative gradient method attack. """ | |||
from abc import abstractmethod | |||
import copy | |||
import numpy as np | |||
from PIL import Image, ImageOps | |||
@@ -68,13 +69,14 @@ def _reshape_l1_projection(values, eps=3): | |||
return proj_x | |||
def _projection(values, eps, norm_level): | |||
def _projection(values, eps, clip_diff, norm_level): | |||
""" | |||
Implementation of values normalization within eps. | |||
Args: | |||
values (numpy.ndarray): Input data. | |||
eps (float): Project radius. | |||
clip_diff (float): Difference range of clip bounds. | |||
norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible | |||
values: np.inf, 1 or 2. | |||
@@ -88,12 +90,12 @@ def _projection(values, eps, norm_level): | |||
if norm_level in (1, '1'): | |||
sample_batch = values.shape[0] | |||
x_flat = values.reshape(sample_batch, -1) | |||
proj_flat = _reshape_l1_projection(x_flat, eps) | |||
proj_flat = _reshape_l1_projection(x_flat, eps*clip_diff) | |||
return proj_flat.reshape(values.shape) | |||
if norm_level in (2, '2'): | |||
return eps*normalize_value(values, norm_level) | |||
if norm_level in (np.inf, 'inf'): | |||
return eps*np.sign(values) | |||
return eps*clip_diff*np.sign(values) | |||
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ | |||
'currently not supported.' | |||
LOGGER.error(TAG, msg) | |||
@@ -132,7 +134,6 @@ class IterativeGradientMethod(Attack): | |||
self._loss_grad = network | |||
else: | |||
self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) | |||
self._loss_grad.set_train() | |||
@abstractmethod | |||
def generate(self, inputs, labels): | |||
@@ -407,10 +408,12 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
np.inf, 1 or 2. Default: 'inf'. | |||
loss_fn (Loss): Loss function for optimization. If None, the input network \ | |||
is already equipped with loss function. Default: None. | |||
random_start (bool): If True, use random perturbs at the beginning. If False, | |||
start from original samples. | |||
""" | |||
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | |||
is_targeted=False, nb_iter=5, norm_level='inf', loss_fn=None): | |||
is_targeted=False, nb_iter=5, norm_level='inf', loss_fn=None, random_start=False): | |||
super(ProjectedGradientDescent, self).__init__(network, | |||
eps=eps, | |||
eps_iter=eps_iter, | |||
@@ -419,6 +422,10 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
nb_iter=nb_iter, | |||
loss_fn=loss_fn) | |||
self._norm_level = check_norm_level(norm_level) | |||
self._random_start = check_param_type('random_start', random_start, bool) | |||
def _get_random_start(self, inputs): | |||
return inputs + np.random.uniform(-self._eps, self._eps, size=inputs.shape).astype(np.float32) | |||
def generate(self, inputs, labels): | |||
""" | |||
@@ -442,33 +449,29 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||
""" | |||
inputs_image, inputs, labels = check_inputs_labels(inputs, labels) | |||
arr_x = inputs_image | |||
adv_x = copy.deepcopy(inputs_image) | |||
if self._bounds is not None: | |||
clip_min, clip_max = self._bounds | |||
clip_diff = clip_max - clip_min | |||
for _ in range(self._nb_iter): | |||
adv_x = self._attack.generate(inputs, labels) | |||
perturs = _projection(adv_x - arr_x, | |||
self._eps, | |||
norm_level=self._norm_level) | |||
perturs = np.clip(perturs, (0 - self._eps)*clip_diff, | |||
self._eps*clip_diff) | |||
adv_x = arr_x + perturs | |||
if isinstance(inputs, tuple): | |||
inputs = (adv_x,) + inputs[1:] | |||
else: | |||
inputs = adv_x | |||
else: | |||
for _ in range(self._nb_iter): | |||
adv_x = self._attack.generate(inputs, labels) | |||
perturs = _projection(adv_x - arr_x, | |||
self._eps, | |||
norm_level=self._norm_level) | |||
adv_x = arr_x + perturs | |||
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | |||
if isinstance(inputs, tuple): | |||
inputs = (adv_x,) + inputs[1:] | |||
else: | |||
inputs = adv_x | |||
clip_diff = 1 | |||
if self._random_start: | |||
inputs = self._get_random_start(inputs) | |||
for _ in range(self._nb_iter): | |||
inputs_tensor = to_tensor_tuple(inputs) | |||
labels_tensor = to_tensor_tuple(labels) | |||
out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) | |||
gradient = out_grad.asnumpy() | |||
perturbs = _projection(gradient, self._eps_iter, clip_diff, norm_level=self._norm_level) | |||
sum_perturbs = adv_x - arr_x + perturbs | |||
sum_perturbs = np.clip(sum_perturbs, (0 - self._eps)*clip_diff, self._eps*clip_diff) | |||
adv_x = arr_x + sum_perturbs | |||
if self._bounds is not None: | |||
adv_x = np.clip(adv_x, clip_min, clip_max) | |||
if isinstance(inputs, tuple): | |||
inputs = (adv_x,) + inputs[1:] | |||
else: | |||
inputs = adv_x | |||
return adv_x | |||
@@ -134,7 +134,6 @@ class JSMAAttack(Attack): | |||
ori_shape = data.shape | |||
temp = data.flatten() | |||
bit_map = np.ones_like(temp) | |||
fake_res = np.zeros_like(data) | |||
counter = np.zeros_like(temp) | |||
perturbed = np.copy(temp) | |||
for _ in range(self._max_iter): | |||
@@ -167,7 +166,7 @@ class JSMAAttack(Attack): | |||
bit_map[p2_ind] = 0 | |||
perturbed = np.clip(perturbed, self._min, self._max) | |||
LOGGER.debug(TAG, 'fail to find adversarial sample.') | |||
return fake_res | |||
return perturbed.reshape(ori_shape) | |||
def generate(self, inputs, labels): | |||
""" | |||
@@ -136,6 +136,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense): | |||
replace_ratio, | |||
0, 1) | |||
self._graph_initialized = False | |||
self._train_net.set_train() | |||
def defense(self, inputs, labels): | |||
""" | |||
@@ -39,6 +39,8 @@ class ProjectedAdversarialDefense(AdversarialDefenseWithAttacks): | |||
nb_iter (int): PGD attack parameters, number of iteration. | |||
Default: 5. | |||
norm_level (str): Norm type. 'inf' or 'l2'. Default: 'inf'. | |||
random_start (bool): If True, use random perturbs at the beginning. If False, | |||
start from original samples. | |||
Examples: | |||
>>> net = Net() | |||
@@ -54,14 +56,16 @@ class ProjectedAdversarialDefense(AdversarialDefenseWithAttacks): | |||
eps=0.3, | |||
eps_iter=0.1, | |||
nb_iter=5, | |||
norm_level='inf'): | |||
norm_level='inf', | |||
random_start=True): | |||
attack = ProjectedGradientDescent(network, | |||
eps=eps, | |||
eps_iter=eps_iter, | |||
nb_iter=nb_iter, | |||
bounds=bounds, | |||
norm_level=norm_level, | |||
loss_fn=loss_fn) | |||
loss_fn=loss_fn, | |||
random_start=random_start) | |||
super(ProjectedAdversarialDefense, self).__init__( | |||
network, [attack], loss_fn=loss_fn, optimizer=optimizer, | |||
bounds=bounds, replace_ratio=replace_ratio) |
@@ -20,7 +20,7 @@ from setuptools import setup | |||
from setuptools.command.egg_info import egg_info | |||
from setuptools.command.build_py import build_py | |||
version = '1.2.0' | |||
version = '1.2.1' | |||
cur_dir = os.path.dirname(os.path.realpath(__file__)) | |||
pkg_dir = os.path.join(cur_dir, 'build') | |||
@@ -26,7 +26,6 @@ from mindarmour.utils.logger import LogUtil | |||
from tests.ut.python.utils.mock_net import Net | |||
context.set_context(mode=context.GRAPH_MODE) | |||
context.set_context(device_target="Ascend") | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'HopSkipJumpAttack' | |||
@@ -91,9 +90,9 @@ def test_hsja_mnist_attack(): | |||
""" | |||
hsja-Attack test | |||
""" | |||
context.set_context(device_target="Ascend") | |||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
# get test data | |||
test_images_set = np.load(os.path.join(current_dir, | |||
'../../../dataset/test_images.npy')) | |||
@@ -159,6 +158,7 @@ def test_hsja_mnist_attack(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
context.set_context(device_target="Ascend") | |||
model = get_model() | |||
norm = 'l2' | |||
with pytest.raises(ValueError): | |||
@@ -26,7 +26,6 @@ from mindarmour.utils.logger import LogUtil | |||
from tests.ut.python.utils.mock_net import Net | |||
context.set_context(mode=context.GRAPH_MODE) | |||
context.set_context(device_target="Ascend") | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'HopSkipJumpAttack' | |||
@@ -103,6 +102,7 @@ def nes_mnist_attack(scene, top_k): | |||
""" | |||
hsja-Attack test | |||
""" | |||
context.set_context(device_target="Ascend") | |||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
test_images, test_labels = get_dataset(current_dir) | |||
model = get_model(current_dir) | |||
@@ -167,6 +167,7 @@ def nes_mnist_attack(scene, top_k): | |||
@pytest.mark.component_mindarmour | |||
def test_nes_query_limit(): | |||
# scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] | |||
context.set_context(device_target="Ascend") | |||
scene = 'Query_Limit' | |||
nes_mnist_attack(scene, top_k=-1) | |||
@@ -178,6 +179,7 @@ def test_nes_query_limit(): | |||
@pytest.mark.component_mindarmour | |||
def test_nes_partial_info(): | |||
# scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] | |||
context.set_context(device_target="Ascend") | |||
scene = 'Partial_Info' | |||
nes_mnist_attack(scene, top_k=5) | |||
@@ -189,6 +191,7 @@ def test_nes_partial_info(): | |||
@pytest.mark.component_mindarmour | |||
def test_nes_label_only(): | |||
# scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] | |||
context.set_context(device_target="Ascend") | |||
scene = 'Label_Only' | |||
nes_mnist_attack(scene, top_k=5) | |||
@@ -200,6 +203,7 @@ def test_nes_label_only(): | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
"""test that exception is raised for invalid labels""" | |||
context.set_context(device_target="Ascend") | |||
with pytest.raises(ValueError): | |||
assert nes_mnist_attack('Label_Only', -1) | |||
@@ -210,6 +214,7 @@ def test_value_error(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_none(): | |||
context.set_context(device_target="Ascend") | |||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
model = get_model(current_dir) | |||
test_images, test_labels = get_dataset(current_dir) | |||
@@ -28,8 +28,6 @@ from mindarmour.utils.logger import LogUtil | |||
from tests.ut.python.utils.mock_net import Net | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Pointwise_Test' | |||
LOGGER.set_level('INFO') | |||
@@ -57,6 +55,7 @@ def test_pointwise_attack_method(): | |||
""" | |||
Pointwise attack method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(123) | |||
# upload trained network | |||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
@@ -26,8 +26,6 @@ from mindarmour import BlackModel | |||
from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack | |||
context.set_context(mode=context.GRAPH_MODE) | |||
context.set_context(device_target="Ascend") | |||
# for user | |||
class ModelToBeAttacked(BlackModel): | |||
@@ -79,6 +77,7 @@ def test_salt_and_pepper_attack_method(): | |||
""" | |||
Salt and pepper attack method unit test. | |||
""" | |||
context.set_context(device_target="Ascend") | |||
batch_size = 6 | |||
np.random.seed(123) | |||
net = SimpleNet() | |||
@@ -105,6 +104,7 @@ def test_salt_and_pepper_attack_in_batch(): | |||
""" | |||
Salt and pepper attack method unit test in batch. | |||
""" | |||
context.set_context(device_target="Ascend") | |||
batch_size = 32 | |||
np.random.seed(123) | |||
net = SimpleNet() | |||
@@ -24,10 +24,6 @@ from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits | |||
from mindarmour.adv_robustness.attacks import FastGradientMethod | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for user | |||
class Net(Cell): | |||
""" | |||
@@ -118,6 +114,7 @@ def test_batch_generate_attack(): | |||
""" | |||
Attack with batch-generate. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.random.random((128, 10)).astype(np.float32) | |||
label = np.random.randint(0, 10, 128).astype(np.int32) | |||
label = np.eye(10)[label].astype(np.float32) | |||
@@ -23,10 +23,6 @@ from mindspore import context | |||
from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for user | |||
class Net(Cell): | |||
""" | |||
@@ -63,6 +59,7 @@ def test_cw_attack(): | |||
""" | |||
CW-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net() | |||
input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) | |||
label_np = np.array([3]).astype(np.int64) | |||
@@ -81,6 +78,7 @@ def test_cw_attack_targeted(): | |||
""" | |||
CW-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net() | |||
input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) | |||
target_np = np.array([1]).astype(np.int64) | |||
@@ -24,7 +24,6 @@ from mindspore import Tensor | |||
from mindarmour.adv_robustness.attacks import DeepFool | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for user | |||
@@ -80,6 +79,7 @@ def test_deepfool_attack(): | |||
""" | |||
Deepfool-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net() | |||
input_shape = (1, 5) | |||
_, classes = input_shape | |||
@@ -105,6 +105,7 @@ def test_deepfool_attack_detection(): | |||
""" | |||
Deepfool-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net2() | |||
inputs1_np = np.random.random((2, 10, 10)).astype(np.float32) | |||
inputs2_np = np.random.random((2, 10, 5)).astype(np.float32) | |||
@@ -128,6 +129,7 @@ def test_deepfool_attack_inf(): | |||
""" | |||
Deepfool-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net() | |||
input_shape = (1, 5) | |||
_, classes = input_shape | |||
@@ -146,6 +148,7 @@ def test_deepfool_attack_inf(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
net = Net() | |||
input_shape = (1, 5) | |||
_, classes = input_shape | |||
@@ -29,9 +29,6 @@ from mindarmour.adv_robustness.attacks import IterativeGradientMethod | |||
from mindarmour.adv_robustness.attacks import DiverseInputIterativeMethod | |||
from mindarmour.adv_robustness.attacks import MomentumDiverseInputIterativeMethod | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for user | |||
class Net(Cell): | |||
""" | |||
@@ -65,6 +62,7 @@ def test_basic_iterative_method(): | |||
""" | |||
Basic iterative method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
@@ -87,6 +85,7 @@ def test_momentum_iterative_method(): | |||
""" | |||
Momentum iterative method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
@@ -108,6 +107,53 @@ def test_projected_gradient_descent_method(): | |||
""" | |||
Projected gradient descent method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
for i in range(5): | |||
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any( | |||
ms_adv_x != input_np), 'Projected gradient descent method: ' \ | |||
'generate value must not be equal to' \ | |||
' original value.' | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_gpu_training | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_projected_gradient_descent_method_gpu(): | |||
""" | |||
Projected gradient descent method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
for i in range(5): | |||
attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
ms_adv_x = attack.generate(input_np, label) | |||
assert np.any( | |||
ms_adv_x != input_np), 'Projected gradient descent method: ' \ | |||
'generate value must not be equal to' \ | |||
' original value.' | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_projected_gradient_descent_method_cpu(): | |||
""" | |||
Projected gradient descent method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
@@ -131,6 +177,7 @@ def test_diverse_input_iterative_method(): | |||
""" | |||
Diverse input iterative method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
@@ -151,6 +198,7 @@ def test_momentum_diverse_input_iterative_method(): | |||
""" | |||
Momentum diverse input iterative method unit test. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
label = np.asarray([2], np.int32) | |||
label = np.eye(3)[label].astype(np.float32) | |||
@@ -168,6 +216,7 @@ def test_momentum_diverse_input_iterative_method(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_error(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) | |||
with pytest.raises(NotImplementedError): | |||
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||
@@ -26,9 +26,6 @@ from mindarmour.utils.logger import LogUtil | |||
from tests.ut.python.utils.mock_net import Net | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'LBFGS_Test' | |||
LOGGER.set_level('DEBUG') | |||
@@ -43,6 +40,7 @@ def test_lbfgs_attack(): | |||
""" | |||
LBFGS-Attack test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(123) | |||
# upload trained network | |||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
@@ -24,8 +24,6 @@ from mindspore.ops.operations import Add | |||
from mindarmour.adv_robustness.detectors import SimilarityDetector | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class EncoderNet(Cell): | |||
""" | |||
@@ -66,6 +64,7 @@ def test_similarity_detector(): | |||
""" | |||
Similarity detector unit test | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# Prepare dataset | |||
np.random.seed(5) | |||
x_train = np.random.rand(1000, 32, 32, 3).astype(np.float32) | |||
@@ -26,8 +26,6 @@ from mindarmour.adv_robustness.detectors import ErrorBasedDetector | |||
from mindarmour.adv_robustness.detectors import RegionBasedDetector | |||
from mindarmour.adv_robustness.detectors import EnsembleDetector | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Net(Cell): | |||
""" | |||
@@ -74,6 +72,7 @@ def test_ensemble_detector(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(6) | |||
adv = np.random.rand(4, 4).astype(np.float32) | |||
model = Model(Net()) | |||
@@ -97,6 +96,7 @@ def test_ensemble_detector(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_error(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(6) | |||
adv = np.random.rand(4, 4).astype(np.float32) | |||
model = Model(Net()) | |||
@@ -26,8 +26,6 @@ from mindspore import context | |||
from mindarmour.adv_robustness.detectors import ErrorBasedDetector | |||
from mindarmour.adv_robustness.detectors import DivergenceBasedDetector | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Net(Cell): | |||
""" | |||
@@ -79,6 +77,7 @@ def test_mag_net(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(5) | |||
ori = np.random.rand(4, 4, 4).astype(np.float32) | |||
np.random.seed(6) | |||
@@ -100,6 +99,7 @@ def test_mag_net_transform(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(6) | |||
adv = np.random.rand(4, 4, 4).astype(np.float32) | |||
model = Model(Net()) | |||
@@ -117,6 +117,7 @@ def test_mag_net_divergence(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(5) | |||
ori = np.random.rand(4, 4, 4).astype(np.float32) | |||
np.random.seed(6) | |||
@@ -140,6 +141,7 @@ def test_mag_net_divergence_transform(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(6) | |||
adv = np.random.rand(4, 4, 4).astype(np.float32) | |||
encoder = Model(Net()) | |||
@@ -155,6 +157,7 @@ def test_mag_net_divergence_transform(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(6) | |||
adv = np.random.rand(4, 4, 4).astype(np.float32) | |||
encoder = Model(Net()) | |||
@@ -25,9 +25,6 @@ from mindspore.ops.operations import Add | |||
from mindarmour.adv_robustness.detectors import RegionBasedDetector | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Net(Cell): | |||
""" | |||
Construct the network of target model. | |||
@@ -55,6 +52,7 @@ def test_region_based_classification(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(5) | |||
ori = np.random.rand(4, 4).astype(np.float32) | |||
labels = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], | |||
@@ -76,6 +74,7 @@ def test_region_based_classification(): | |||
@pytest.mark.env_card | |||
@pytest.mark.component_mindarmour | |||
def test_value_error(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
np.random.seed(5) | |||
ori = np.random.rand(4, 4).astype(np.float32) | |||
labels = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], | |||
@@ -24,8 +24,6 @@ from mindspore import context | |||
from mindarmour.adv_robustness.detectors import SpatialSmoothing | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# for use | |||
class Net(Cell): | |||
@@ -55,6 +53,7 @@ def test_spatial_smoothing(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_shape = (50, 3) | |||
np.random.seed(1) | |||
@@ -84,6 +83,7 @@ def test_spatial_smoothing_diff(): | |||
""" | |||
Compute mindspore result. | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
input_shape = (50, 3) | |||
np.random.seed(1) | |||
input_np = np.random.randn(*input_shape).astype(np.float32) | |||