From 7f8e9de6dce638abc0cca53baeec18a756364cd4 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Mon, 26 Oct 2020 16:14:24 +0800 Subject: [PATCH] Fixed support for Faster RCNN --- mindarmour/adv_robustness/attacks/attack.py | 36 ++++-- .../adv_robustness/attacks/test_gradient_method.py | 122 ++++++++++++++++++++- 2 files changed, 145 insertions(+), 13 deletions(-) diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index 5610427..ca17956 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -41,8 +41,8 @@ class Attack: Args: inputs (numpy.ndarray): Samples based on which adversarial examples are generated. - labels (numpy.ndarray): Labels of samples, whose values determined - by specific attacks. + labels (Union[numpy.ndarray, tuple]): Original/target labels. \ + For each input if it has more than one label, it is wrapped in a tuple. batch_size (int): The number of samples in one batch. Returns: @@ -53,22 +53,36 @@ class Attack: >>> labels = np.array([3, 0]) >>> advs = attack.batch_generate(inputs, labels, batch_size=2) """ - arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels) + if isinstance(labels, tuple): + for i, labels_item in enumerate(labels): + arr_x, _ = check_pair_numpy_param('inputs', inputs, \ + 'labels[{}]'.format(i), labels_item) + else: + arr_x, _ = check_pair_numpy_param('inputs', inputs, \ + 'labels', labels) + arr_y = labels len_x = arr_x.shape[0] batch_size = check_int_positive('batch_size', batch_size) - batchs = int(len_x / batch_size) - rest = len_x - batchs*batch_size + batches = int(len_x / batch_size) + rest = len_x - batches*batch_size res = [] - for i in range(batchs): + for i in range(batches): x_batch = arr_x[i*batch_size: (i + 1)*batch_size] - y_batch = arr_y[i*batch_size: (i + 1)*batch_size] + if isinstance(arr_y, tuple): + y_batch = tuple([sub_labels[i*batch_size: (i + 1)*batch_size] for sub_labels in arr_y]) + else: + y_batch = arr_y[i*batch_size: (i + 1)*batch_size] adv_x = self.generate(x_batch, y_batch) # Black-attack methods will return 3 values, just get the second. res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x) if rest != 0: - x_batch = arr_x[batchs*batch_size:] - y_batch = arr_y[batchs*batch_size:] + x_batch = arr_x[batches*batch_size:] + if isinstance(arr_y, tuple): + y_batch = tuple([sub_labels[batches*batch_size:] for sub_labels in arr_y]) + else: + y_batch = arr_y[batches*batch_size:] + y_batch = arr_y[batches*batch_size:] adv_x = self.generate(x_batch, y_batch) # Black-attack methods will return 3 values, just get the second. res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x) @@ -85,8 +99,8 @@ class Attack: Args: inputs (numpy.ndarray): Samples based on which adversarial examples are generated. - labels (numpy.ndarray): Labels of samples, whose values determined - by specific attacks. + labels (Union[numpy.ndarray, tuple]): Original/target labels. \ + For each input if it has more than one label, it is wrapped in a tuple. Raises: NotImplementedError: It is an abstract method. diff --git a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py index bf7a638..cdbe3ba 100644 --- a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py +++ b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py @@ -18,9 +18,9 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore.nn import Cell +from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits import mindspore.context as context -from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore.ops.composite import GradOperation from mindarmour.adv_robustness.attacks import FastGradientMethod from mindarmour.adv_robustness.attacks import FastGradientSignMethod @@ -57,6 +57,52 @@ class Net(Cell): return out +class Net2(Cell): + """ + Construct the network of target model. A network with multiple input data. + + Examples: + >>> net = Net2() + """ + + def __init__(self): + super(Net2, self).__init__() + self._relu = nn.ReLU() + + def construct(self, inputs1, inputs2): + out1 = self._relu(inputs1) + out2 = self._relu(inputs2) + return out1 + out2 + + +class WithLossCell(Cell): + """Wrap the network with loss function""" + def __init__(self, backbone, loss_fn): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, inputs1, inputs2, labels): + out = self._backbone(inputs1, inputs2) + return self._loss_fn(out, labels) + + +class GradWrapWithLoss(Cell): + """ + Construct a network to compute the gradient of loss function in \ + input space and weighted by 'weight'. + """ + + def __init__(self, network): + super(GradWrapWithLoss, self).__init__() + self._grad_all = GradOperation(get_all=True, sens_param=False) + self._network = network + + def construct(self, inputs1, inputs2, labels): + gout = self._grad_all(self._network)(inputs1, inputs2, labels) + return gout[0] + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -234,6 +280,78 @@ def test_random_least_likely_class_method(): @pytest.mark.platform_x86_ascend_training @pytest.mark.env_card @pytest.mark.component_mindarmour +def test_fast_gradient_method_multi_inputs(): + """ + Fast gradient method unit test. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_np = np.asarray([[0.1, 0.2, 0.7]]).astype(np.float32) + anno_np = np.asarray([[0.4, 0.8, 0.5]]).astype(np.float32) + label = np.asarray([2], np.int32) + label = np.eye(3)[label].astype(np.float32) + + loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) + with_loss_cell = WithLossCell(Net2(), loss_fn) + grad_with_loss_net = GradWrapWithLoss(with_loss_cell) + attack = FastGradientMethod(grad_with_loss_net) + ms_adv_x = attack.generate(input_np, (anno_np, label)) + + assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ + ' must not be equal to original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_batch_generate(): + """ + Fast gradient method unit test. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_np = np.random.random([10, 3]).astype(np.float32) + label = np.random.randint(0, 3, [10]) + label = np.eye(3)[label].astype(np.float32) + + loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) + attack = FastGradientMethod(Net(), loss_fn=loss_fn) + ms_adv_x = attack.batch_generate(input_np, label, 4) + + assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ + ' must not be equal to original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_batch_generate_multi_inputs(): + """ + Fast gradient method unit test. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_np = np.random.random([10, 3]).astype(np.float32) + anno_np = np.random.random([10, 3]).astype(np.float32) + label = np.random.randint(0, 3, [10]) + label = np.eye(3)[label].astype(np.float32) + + loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False) + with_loss_cell = WithLossCell(Net2(), loss_fn) + grad_with_loss_net = GradWrapWithLoss(with_loss_cell) + attack = FastGradientMethod(grad_with_loss_net) + ms_adv_x = attack.generate(input_np, (anno_np, label)) + + assert np.any(ms_adv_x != input_np), 'Fast gradient method: generate value' \ + ' must not be equal to original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour def test_assert_error(): """ Random least likely class method unit test.