From dbea7e6a54de44fbf9f74529d0b0aa894d64a47d Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Tue, 28 Dec 2021 16:40:13 +0800 Subject: [PATCH] patch for version 1.2 --- RELEASE.md | 39 ++++++++++++++ .../adv_robustness/attacks/gradient_method.py | 1 - .../attacks/iterative_gradient_method.py | 59 ++++++++++++---------- mindarmour/adv_robustness/attacks/jsma.py | 3 +- .../adv_robustness/defenses/adversarial_defense.py | 1 + .../defenses/projected_adversarial_defense.py | 8 ++- setup.py | 2 +- .../adv_robustness/attacks/black/test_hsja.py | 4 +- .../adv_robustness/attacks/black/test_nes.py | 7 ++- .../attacks/black/test_pointwise_attack.py | 3 +- .../attacks/black/test_salt_and_pepper_attack.py | 4 +- .../attacks/test_batch_generate_attack.py | 5 +- tests/ut/python/adv_robustness/attacks/test_cw.py | 6 +-- .../adv_robustness/attacks/test_deep_fool.py | 5 +- .../attacks/test_iterative_gradient_method.py | 55 ++++++++++++++++++-- .../ut/python/adv_robustness/attacks/test_lbfgs.py | 4 +- .../detectors/black/test_similarity_detector.py | 3 +- .../detectors/test_ensemble_detector.py | 4 +- .../adv_robustness/detectors/test_mag_net.py | 7 ++- .../detectors/test_region_based_detector.py | 5 +- .../detectors/test_spatial_smoothing.py | 4 +- 21 files changed, 162 insertions(+), 67 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index bc04a67..ec532a8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,42 @@ +# MindArmour 1.2.1 + +## MindArmour 1.2.1 Release Notes + +### Major Features and Improvements + +### API Change + +#### Backwards Incompatible Change + +##### C++ API + +[Modify] ... +[Add] ... +[Delete] ... + +##### Java API + +[Add] ... + +#### Deprecations + +##### C++ API + +##### Java API + +### Bug fixes + +* [BUGFIX] Fix a bug of PGD method +* [BUGFIX] Fix a bug of JSMA method + +### Contributors + +Thanks goes to these wonderful people: + +Liu Liu, Zhidan Liu, Luobin Liu and Xiulang Jin. + +Contributions of any kind are welcome! + # MindArmour 1.2.0 ## MindArmour 1.2.0 Release Notes diff --git a/mindarmour/adv_robustness/attacks/gradient_method.py b/mindarmour/adv_robustness/attacks/gradient_method.py index cd7e2e1..06e3979 100644 --- a/mindarmour/adv_robustness/attacks/gradient_method.py +++ b/mindarmour/adv_robustness/attacks/gradient_method.py @@ -75,7 +75,6 @@ class GradientMethod(Attack): else: with_loss_cell = WithLossCell(self._network, loss_fn) self._grad_all = GradWrapWithLoss(with_loss_cell) - self._grad_all.set_train() def generate(self, inputs, labels): """ diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py index 94acf25..c6e80ed 100644 --- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py +++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py @@ -14,6 +14,7 @@ """ Iterative gradient method attack. """ from abc import abstractmethod +import copy import numpy as np from PIL import Image, ImageOps @@ -68,13 +69,14 @@ def _reshape_l1_projection(values, eps=3): return proj_x -def _projection(values, eps, norm_level): +def _projection(values, eps, clip_diff, norm_level): """ Implementation of values normalization within eps. Args: values (numpy.ndarray): Input data. eps (float): Project radius. + clip_diff (float): Difference range of clip bounds. norm_level (Union[int, char, numpy.inf]): Order of the norm. Possible values: np.inf, 1 or 2. @@ -88,12 +90,12 @@ def _projection(values, eps, norm_level): if norm_level in (1, '1'): sample_batch = values.shape[0] x_flat = values.reshape(sample_batch, -1) - proj_flat = _reshape_l1_projection(x_flat, eps) + proj_flat = _reshape_l1_projection(x_flat, eps*clip_diff) return proj_flat.reshape(values.shape) if norm_level in (2, '2'): return eps*normalize_value(values, norm_level) if norm_level in (np.inf, 'inf'): - return eps*np.sign(values) + return eps*clip_diff*np.sign(values) msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are ' \ 'currently not supported.' LOGGER.error(TAG, msg) @@ -132,7 +134,6 @@ class IterativeGradientMethod(Attack): self._loss_grad = network else: self._loss_grad = GradWrapWithLoss(WithLossCell(self._network, loss_fn)) - self._loss_grad.set_train() @abstractmethod def generate(self, inputs, labels): @@ -407,10 +408,12 @@ class ProjectedGradientDescent(BasicIterativeMethod): np.inf, 1 or 2. Default: 'inf'. loss_fn (Loss): Loss function for optimization. If None, the input network \ is already equipped with loss function. Default: None. + random_start (bool): If True, use random perturbs at the beginning. If False, + start from original samples. """ def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), - is_targeted=False, nb_iter=5, norm_level='inf', loss_fn=None): + is_targeted=False, nb_iter=5, norm_level='inf', loss_fn=None, random_start=False): super(ProjectedGradientDescent, self).__init__(network, eps=eps, eps_iter=eps_iter, @@ -419,6 +422,10 @@ class ProjectedGradientDescent(BasicIterativeMethod): nb_iter=nb_iter, loss_fn=loss_fn) self._norm_level = check_norm_level(norm_level) + self._random_start = check_param_type('random_start', random_start, bool) + + def _get_random_start(self, inputs): + return inputs + np.random.uniform(-self._eps, self._eps, size=inputs.shape).astype(np.float32) def generate(self, inputs, labels): """ @@ -442,33 +449,29 @@ class ProjectedGradientDescent(BasicIterativeMethod): """ inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs_image + adv_x = copy.deepcopy(inputs_image) if self._bounds is not None: clip_min, clip_max = self._bounds clip_diff = clip_max - clip_min - for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) - perturs = _projection(adv_x - arr_x, - self._eps, - norm_level=self._norm_level) - perturs = np.clip(perturs, (0 - self._eps)*clip_diff, - self._eps*clip_diff) - adv_x = arr_x + perturs - if isinstance(inputs, tuple): - inputs = (adv_x,) + inputs[1:] - else: - inputs = adv_x else: - for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) - perturs = _projection(adv_x - arr_x, - self._eps, - norm_level=self._norm_level) - adv_x = arr_x + perturs - adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) - if isinstance(inputs, tuple): - inputs = (adv_x,) + inputs[1:] - else: - inputs = adv_x + clip_diff = 1 + if self._random_start: + inputs = self._get_random_start(inputs) + for _ in range(self._nb_iter): + inputs_tensor = to_tensor_tuple(inputs) + labels_tensor = to_tensor_tuple(labels) + out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) + gradient = out_grad.asnumpy() + perturbs = _projection(gradient, self._eps_iter, clip_diff, norm_level=self._norm_level) + sum_perturbs = adv_x - arr_x + perturbs + sum_perturbs = np.clip(sum_perturbs, (0 - self._eps)*clip_diff, self._eps*clip_diff) + adv_x = arr_x + sum_perturbs + if self._bounds is not None: + adv_x = np.clip(adv_x, clip_min, clip_max) + if isinstance(inputs, tuple): + inputs = (adv_x,) + inputs[1:] + else: + inputs = adv_x return adv_x diff --git a/mindarmour/adv_robustness/attacks/jsma.py b/mindarmour/adv_robustness/attacks/jsma.py index 996f031..7be4721 100644 --- a/mindarmour/adv_robustness/attacks/jsma.py +++ b/mindarmour/adv_robustness/attacks/jsma.py @@ -134,7 +134,6 @@ class JSMAAttack(Attack): ori_shape = data.shape temp = data.flatten() bit_map = np.ones_like(temp) - fake_res = np.zeros_like(data) counter = np.zeros_like(temp) perturbed = np.copy(temp) for _ in range(self._max_iter): @@ -167,7 +166,7 @@ class JSMAAttack(Attack): bit_map[p2_ind] = 0 perturbed = np.clip(perturbed, self._min, self._max) LOGGER.debug(TAG, 'fail to find adversarial sample.') - return fake_res + return perturbed.reshape(ori_shape) def generate(self, inputs, labels): """ diff --git a/mindarmour/adv_robustness/defenses/adversarial_defense.py b/mindarmour/adv_robustness/defenses/adversarial_defense.py index 6ef6648..039552f 100644 --- a/mindarmour/adv_robustness/defenses/adversarial_defense.py +++ b/mindarmour/adv_robustness/defenses/adversarial_defense.py @@ -136,6 +136,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense): replace_ratio, 0, 1) self._graph_initialized = False + self._train_net.set_train() def defense(self, inputs, labels): """ diff --git a/mindarmour/adv_robustness/defenses/projected_adversarial_defense.py b/mindarmour/adv_robustness/defenses/projected_adversarial_defense.py index 0b44869..3d4b986 100644 --- a/mindarmour/adv_robustness/defenses/projected_adversarial_defense.py +++ b/mindarmour/adv_robustness/defenses/projected_adversarial_defense.py @@ -39,6 +39,8 @@ class ProjectedAdversarialDefense(AdversarialDefenseWithAttacks): nb_iter (int): PGD attack parameters, number of iteration. Default: 5. norm_level (str): Norm type. 'inf' or 'l2'. Default: 'inf'. + random_start (bool): If True, use random perturbs at the beginning. If False, + start from original samples. Examples: >>> net = Net() @@ -54,14 +56,16 @@ class ProjectedAdversarialDefense(AdversarialDefenseWithAttacks): eps=0.3, eps_iter=0.1, nb_iter=5, - norm_level='inf'): + norm_level='inf', + random_start=True): attack = ProjectedGradientDescent(network, eps=eps, eps_iter=eps_iter, nb_iter=nb_iter, bounds=bounds, norm_level=norm_level, - loss_fn=loss_fn) + loss_fn=loss_fn, + random_start=random_start) super(ProjectedAdversarialDefense, self).__init__( network, [attack], loss_fn=loss_fn, optimizer=optimizer, bounds=bounds, replace_ratio=replace_ratio) diff --git a/setup.py b/setup.py index d92df61..77d8082 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ from setuptools import setup from setuptools.command.egg_info import egg_info from setuptools.command.build_py import build_py -version = '1.2.0' +version = '1.2.1' cur_dir = os.path.dirname(os.path.realpath(__file__)) pkg_dir = os.path.join(cur_dir, 'build') diff --git a/tests/ut/python/adv_robustness/attacks/black/test_hsja.py b/tests/ut/python/adv_robustness/attacks/black/test_hsja.py index c9e6d24..8c368c9 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_hsja.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_hsja.py @@ -26,7 +26,6 @@ from mindarmour.utils.logger import LogUtil from tests.ut.python.utils.mock_net import Net context.set_context(mode=context.GRAPH_MODE) -context.set_context(device_target="Ascend") LOGGER = LogUtil.get_instance() TAG = 'HopSkipJumpAttack' @@ -91,9 +90,9 @@ def test_hsja_mnist_attack(): """ hsja-Attack test """ + context.set_context(device_target="Ascend") current_dir = os.path.dirname(os.path.abspath(__file__)) - # get test data test_images_set = np.load(os.path.join(current_dir, '../../../dataset/test_images.npy')) @@ -159,6 +158,7 @@ def test_hsja_mnist_attack(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_value_error(): + context.set_context(device_target="Ascend") model = get_model() norm = 'l2' with pytest.raises(ValueError): diff --git a/tests/ut/python/adv_robustness/attacks/black/test_nes.py b/tests/ut/python/adv_robustness/attacks/black/test_nes.py index 6f99a01..25fc5e5 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_nes.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_nes.py @@ -26,7 +26,6 @@ from mindarmour.utils.logger import LogUtil from tests.ut.python.utils.mock_net import Net context.set_context(mode=context.GRAPH_MODE) -context.set_context(device_target="Ascend") LOGGER = LogUtil.get_instance() TAG = 'HopSkipJumpAttack' @@ -103,6 +102,7 @@ def nes_mnist_attack(scene, top_k): """ hsja-Attack test """ + context.set_context(device_target="Ascend") current_dir = os.path.dirname(os.path.abspath(__file__)) test_images, test_labels = get_dataset(current_dir) model = get_model(current_dir) @@ -167,6 +167,7 @@ def nes_mnist_attack(scene, top_k): @pytest.mark.component_mindarmour def test_nes_query_limit(): # scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] + context.set_context(device_target="Ascend") scene = 'Query_Limit' nes_mnist_attack(scene, top_k=-1) @@ -178,6 +179,7 @@ def test_nes_query_limit(): @pytest.mark.component_mindarmour def test_nes_partial_info(): # scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] + context.set_context(device_target="Ascend") scene = 'Partial_Info' nes_mnist_attack(scene, top_k=5) @@ -189,6 +191,7 @@ def test_nes_partial_info(): @pytest.mark.component_mindarmour def test_nes_label_only(): # scene is in ['Query_Limit', 'Partial_Info', 'Label_Only'] + context.set_context(device_target="Ascend") scene = 'Label_Only' nes_mnist_attack(scene, top_k=5) @@ -200,6 +203,7 @@ def test_nes_label_only(): @pytest.mark.component_mindarmour def test_value_error(): """test that exception is raised for invalid labels""" + context.set_context(device_target="Ascend") with pytest.raises(ValueError): assert nes_mnist_attack('Label_Only', -1) @@ -210,6 +214,7 @@ def test_value_error(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_none(): + context.set_context(device_target="Ascend") current_dir = os.path.dirname(os.path.abspath(__file__)) model = get_model(current_dir) test_images, test_labels = get_dataset(current_dir) diff --git a/tests/ut/python/adv_robustness/attacks/black/test_pointwise_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_pointwise_attack.py index 542566b..f464521 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_pointwise_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_pointwise_attack.py @@ -28,8 +28,6 @@ from mindarmour.utils.logger import LogUtil from tests.ut.python.utils.mock_net import Net -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - LOGGER = LogUtil.get_instance() TAG = 'Pointwise_Test' LOGGER.set_level('INFO') @@ -57,6 +55,7 @@ def test_pointwise_attack_method(): """ Pointwise attack method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(123) # upload trained network current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/tests/ut/python/adv_robustness/attacks/black/test_salt_and_pepper_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_salt_and_pepper_attack.py index 28b4743..af304ee 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_salt_and_pepper_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_salt_and_pepper_attack.py @@ -26,8 +26,6 @@ from mindarmour import BlackModel from mindarmour.adv_robustness.attacks import SaltAndPepperNoiseAttack context.set_context(mode=context.GRAPH_MODE) -context.set_context(device_target="Ascend") - # for user class ModelToBeAttacked(BlackModel): @@ -79,6 +77,7 @@ def test_salt_and_pepper_attack_method(): """ Salt and pepper attack method unit test. """ + context.set_context(device_target="Ascend") batch_size = 6 np.random.seed(123) net = SimpleNet() @@ -105,6 +104,7 @@ def test_salt_and_pepper_attack_in_batch(): """ Salt and pepper attack method unit test in batch. """ + context.set_context(device_target="Ascend") batch_size = 32 np.random.seed(123) net = SimpleNet() diff --git a/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py b/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py index dcb4fca..bcac962 100644 --- a/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py +++ b/tests/ut/python/adv_robustness/attacks/test_batch_generate_attack.py @@ -24,10 +24,6 @@ from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits from mindarmour.adv_robustness.attacks import FastGradientMethod - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - # for user class Net(Cell): """ @@ -118,6 +114,7 @@ def test_batch_generate_attack(): """ Attack with batch-generate. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.random.random((128, 10)).astype(np.float32) label = np.random.randint(0, 10, 128).astype(np.int32) label = np.eye(10)[label].astype(np.float32) diff --git a/tests/ut/python/adv_robustness/attacks/test_cw.py b/tests/ut/python/adv_robustness/attacks/test_cw.py index da0eece..5ae1eb3 100644 --- a/tests/ut/python/adv_robustness/attacks/test_cw.py +++ b/tests/ut/python/adv_robustness/attacks/test_cw.py @@ -23,10 +23,6 @@ from mindspore import context from mindarmour.adv_robustness.attacks import CarliniWagnerL2Attack - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - # for user class Net(Cell): """ @@ -63,6 +59,7 @@ def test_cw_attack(): """ CW-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net() input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) label_np = np.array([3]).astype(np.int64) @@ -81,6 +78,7 @@ def test_cw_attack_targeted(): """ CW-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net() input_np = np.array([[0.1, 0.2, 0.7, 0.5, 0.4]]).astype(np.float32) target_np = np.array([1]).astype(np.int64) diff --git a/tests/ut/python/adv_robustness/attacks/test_deep_fool.py b/tests/ut/python/adv_robustness/attacks/test_deep_fool.py index 005e8ea..5713265 100644 --- a/tests/ut/python/adv_robustness/attacks/test_deep_fool.py +++ b/tests/ut/python/adv_robustness/attacks/test_deep_fool.py @@ -24,7 +24,6 @@ from mindspore import Tensor from mindarmour.adv_robustness.attacks import DeepFool -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") # for user @@ -80,6 +79,7 @@ def test_deepfool_attack(): """ Deepfool-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net() input_shape = (1, 5) _, classes = input_shape @@ -105,6 +105,7 @@ def test_deepfool_attack_detection(): """ Deepfool-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net2() inputs1_np = np.random.random((2, 10, 10)).astype(np.float32) inputs2_np = np.random.random((2, 10, 5)).astype(np.float32) @@ -128,6 +129,7 @@ def test_deepfool_attack_inf(): """ Deepfool-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net() input_shape = (1, 5) _, classes = input_shape @@ -146,6 +148,7 @@ def test_deepfool_attack_inf(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_value_error(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") net = Net() input_shape = (1, 5) _, classes = input_shape diff --git a/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py b/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py index da330bc..74b7532 100644 --- a/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py +++ b/tests/ut/python/adv_robustness/attacks/test_iterative_gradient_method.py @@ -29,9 +29,6 @@ from mindarmour.adv_robustness.attacks import IterativeGradientMethod from mindarmour.adv_robustness.attacks import DiverseInputIterativeMethod from mindarmour.adv_robustness.attacks import MomentumDiverseInputIterativeMethod -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - # for user class Net(Cell): """ @@ -65,6 +62,7 @@ def test_basic_iterative_method(): """ Basic iterative method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -87,6 +85,7 @@ def test_momentum_iterative_method(): """ Momentum iterative method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -108,6 +107,53 @@ def test_projected_gradient_descent_method(): """ Projected gradient descent method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) + label = np.asarray([2], np.int32) + label = np.eye(3)[label].astype(np.float32) + + for i in range(5): + attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) + ms_adv_x = attack.generate(input_np, label) + + assert np.any( + ms_adv_x != input_np), 'Projected gradient descent method: ' \ + 'generate value must not be equal to' \ + ' original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_projected_gradient_descent_method_gpu(): + """ + Projected gradient descent method unit test. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) + label = np.asarray([2], np.int32) + label = np.eye(3)[label].astype(np.float32) + + for i in range(5): + attack = ProjectedGradientDescent(Net(), nb_iter=i + 1, loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) + ms_adv_x = attack.generate(input_np, label) + + assert np.any( + ms_adv_x != input_np), 'Projected gradient descent method: ' \ + 'generate value must not be equal to' \ + ' original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_projected_gradient_descent_method_cpu(): + """ + Projected gradient descent method unit test. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -131,6 +177,7 @@ def test_diverse_input_iterative_method(): """ Diverse input iterative method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -151,6 +198,7 @@ def test_momentum_diverse_input_iterative_method(): """ Momentum diverse input iterative method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -168,6 +216,7 @@ def test_momentum_diverse_input_iterative_method(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_error(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0), loss_fn=SoftmaxCrossEntropyWithLogits(sparse=False)) with pytest.raises(NotImplementedError): input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) diff --git a/tests/ut/python/adv_robustness/attacks/test_lbfgs.py b/tests/ut/python/adv_robustness/attacks/test_lbfgs.py index ce303d7..9e24aa7 100644 --- a/tests/ut/python/adv_robustness/attacks/test_lbfgs.py +++ b/tests/ut/python/adv_robustness/attacks/test_lbfgs.py @@ -26,9 +26,6 @@ from mindarmour.utils.logger import LogUtil from tests.ut.python.utils.mock_net import Net -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - LOGGER = LogUtil.get_instance() TAG = 'LBFGS_Test' LOGGER.set_level('DEBUG') @@ -43,6 +40,7 @@ def test_lbfgs_attack(): """ LBFGS-Attack test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(123) # upload trained network current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/tests/ut/python/adv_robustness/detectors/black/test_similarity_detector.py b/tests/ut/python/adv_robustness/detectors/black/test_similarity_detector.py index 0773d68..0a06bca 100644 --- a/tests/ut/python/adv_robustness/detectors/black/test_similarity_detector.py +++ b/tests/ut/python/adv_robustness/detectors/black/test_similarity_detector.py @@ -24,8 +24,6 @@ from mindspore.ops.operations import Add from mindarmour.adv_robustness.detectors import SimilarityDetector -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - class EncoderNet(Cell): """ @@ -66,6 +64,7 @@ def test_similarity_detector(): """ Similarity detector unit test """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") # Prepare dataset np.random.seed(5) x_train = np.random.rand(1000, 32, 32, 3).astype(np.float32) diff --git a/tests/ut/python/adv_robustness/detectors/test_ensemble_detector.py b/tests/ut/python/adv_robustness/detectors/test_ensemble_detector.py index 8c0dc22..271aebd 100644 --- a/tests/ut/python/adv_robustness/detectors/test_ensemble_detector.py +++ b/tests/ut/python/adv_robustness/detectors/test_ensemble_detector.py @@ -26,8 +26,6 @@ from mindarmour.adv_robustness.detectors import ErrorBasedDetector from mindarmour.adv_robustness.detectors import RegionBasedDetector from mindarmour.adv_robustness.detectors import EnsembleDetector -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - class Net(Cell): """ @@ -74,6 +72,7 @@ def test_ensemble_detector(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(6) adv = np.random.rand(4, 4).astype(np.float32) model = Model(Net()) @@ -97,6 +96,7 @@ def test_ensemble_detector(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_error(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(6) adv = np.random.rand(4, 4).astype(np.float32) model = Model(Net()) diff --git a/tests/ut/python/adv_robustness/detectors/test_mag_net.py b/tests/ut/python/adv_robustness/detectors/test_mag_net.py index bcaa341..6a403c6 100644 --- a/tests/ut/python/adv_robustness/detectors/test_mag_net.py +++ b/tests/ut/python/adv_robustness/detectors/test_mag_net.py @@ -26,8 +26,6 @@ from mindspore import context from mindarmour.adv_robustness.detectors import ErrorBasedDetector from mindarmour.adv_robustness.detectors import DivergenceBasedDetector -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - class Net(Cell): """ @@ -79,6 +77,7 @@ def test_mag_net(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(5) ori = np.random.rand(4, 4, 4).astype(np.float32) np.random.seed(6) @@ -100,6 +99,7 @@ def test_mag_net_transform(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(6) adv = np.random.rand(4, 4, 4).astype(np.float32) model = Model(Net()) @@ -117,6 +117,7 @@ def test_mag_net_divergence(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(5) ori = np.random.rand(4, 4, 4).astype(np.float32) np.random.seed(6) @@ -140,6 +141,7 @@ def test_mag_net_divergence_transform(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(6) adv = np.random.rand(4, 4, 4).astype(np.float32) encoder = Model(Net()) @@ -155,6 +157,7 @@ def test_mag_net_divergence_transform(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_value_error(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(6) adv = np.random.rand(4, 4, 4).astype(np.float32) encoder = Model(Net()) diff --git a/tests/ut/python/adv_robustness/detectors/test_region_based_detector.py b/tests/ut/python/adv_robustness/detectors/test_region_based_detector.py index b515d12..882fbdb 100644 --- a/tests/ut/python/adv_robustness/detectors/test_region_based_detector.py +++ b/tests/ut/python/adv_robustness/detectors/test_region_based_detector.py @@ -25,9 +25,6 @@ from mindspore.ops.operations import Add from mindarmour.adv_robustness.detectors import RegionBasedDetector -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - class Net(Cell): """ Construct the network of target model. @@ -55,6 +52,7 @@ def test_region_based_classification(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(5) ori = np.random.rand(4, 4).astype(np.float32) labels = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], @@ -76,6 +74,7 @@ def test_region_based_classification(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_value_error(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") np.random.seed(5) ori = np.random.rand(4, 4).astype(np.float32) labels = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], diff --git a/tests/ut/python/adv_robustness/detectors/test_spatial_smoothing.py b/tests/ut/python/adv_robustness/detectors/test_spatial_smoothing.py index 5dc6f4e..de27502 100644 --- a/tests/ut/python/adv_robustness/detectors/test_spatial_smoothing.py +++ b/tests/ut/python/adv_robustness/detectors/test_spatial_smoothing.py @@ -24,8 +24,6 @@ from mindspore import context from mindarmour.adv_robustness.detectors import SpatialSmoothing -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - # for use class Net(Cell): @@ -55,6 +53,7 @@ def test_spatial_smoothing(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_shape = (50, 3) np.random.seed(1) @@ -84,6 +83,7 @@ def test_spatial_smoothing_diff(): """ Compute mindspore result. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_shape = (50, 3) np.random.seed(1) input_np = np.random.randn(*input_shape).astype(np.float32)