
Fixed support for Faster RCNN

tags/v1.2.1
liuluobin · commit 7f8e9de6dc

2 changed files with 145 additions and 13 deletions
  1. +25  -11   mindarmour/adv_robustness/attacks/attack.py
  2. +120 -2    tests/ut/python/adv_robustness/attacks/test_gradient_method.py

+25 -11   mindarmour/adv_robustness/attacks/attack.py

@@ -41,8 +41,8 @@ class Attack:
         Args:
             inputs (numpy.ndarray): Samples based on which adversarial
                 examples are generated.
-            labels (numpy.ndarray): Labels of samples, whose values determined
-                by specific attacks.
+            labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+                If each input has more than one label, wrap them in a tuple.
             batch_size (int): The number of samples in one batch.

         Returns:
@@ -53,22 +53,36 @@ class Attack:
             >>> labels = np.array([3, 0])
             >>> advs = attack.batch_generate(inputs, labels, batch_size=2)
         """
-        arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
+        if isinstance(labels, tuple):
+            for i, labels_item in enumerate(labels):
+                arr_x, _ = check_pair_numpy_param('inputs', inputs,
+                                                  'labels[{}]'.format(i), labels_item)
+        else:
+            arr_x, _ = check_pair_numpy_param('inputs', inputs,
+                                              'labels', labels)
+        arr_y = labels
         len_x = arr_x.shape[0]
         batch_size = check_int_positive('batch_size', batch_size)
-        batchs = int(len_x / batch_size)
-        rest = len_x - batchs*batch_size
+        batches = int(len_x / batch_size)
+        rest = len_x - batches*batch_size
         res = []
-        for i in range(batchs):
+        for i in range(batches):
             x_batch = arr_x[i*batch_size: (i + 1)*batch_size]
-            y_batch = arr_y[i*batch_size: (i + 1)*batch_size]
+            if isinstance(arr_y, tuple):
+                y_batch = tuple(sub_labels[i*batch_size: (i + 1)*batch_size] for sub_labels in arr_y)
+            else:
+                y_batch = arr_y[i*batch_size: (i + 1)*batch_size]
             adv_x = self.generate(x_batch, y_batch)
             # Black-box attack methods return 3 values; keep only the second one.
             res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)

         if rest != 0:
-            x_batch = arr_x[batchs*batch_size:]
-            y_batch = arr_y[batchs*batch_size:]
+            x_batch = arr_x[batches*batch_size:]
+            if isinstance(arr_y, tuple):
+                y_batch = tuple(sub_labels[batches*batch_size:] for sub_labels in arr_y)
+            else:
+                y_batch = arr_y[batches*batch_size:]
             adv_x = self.generate(x_batch, y_batch)
             # Black-box attack methods return 3 values; keep only the second one.
             res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)
@@ -85,8 +99,8 @@ class Attack:
         Args:
             inputs (numpy.ndarray): Samples based on which adversarial
                 examples are generated.
-            labels (numpy.ndarray): Labels of samples, whose values determined
-                by specific attacks.
+            labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+                If each input has more than one label, wrap them in a tuple.

         Raises:
             NotImplementedError: It is an abstract method.
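
The substance of the fix: batch_generate now accepts labels either as a single numpy.ndarray or as a tuple of arrays (as a detection model such as Faster RCNN needs, where each sample carries box annotations as well as class labels), and it slices every element of the tuple with the same batch indices so the components stay aligned. A minimal standalone sketch of that slicing behavior, using plain NumPy and a hypothetical slice_labels helper:

    import numpy as np

    def slice_labels(arr_y, start, end):
        """Slice one batch of labels; arr_y may be an ndarray or a tuple of ndarrays."""
        if isinstance(arr_y, tuple):
            # Identical indices for every component keep boxes and classes aligned.
            return tuple(sub_labels[start:end] for sub_labels in arr_y)
        return arr_y[start:end]

    # Hypothetical detection-style labels: 10 samples with boxes and class ids.
    boxes = np.random.random((10, 4)).astype(np.float32)
    classes = np.random.randint(0, 3, (10,))
    y_batch = slice_labels((boxes, classes), 0, 4)
    assert y_batch[0].shape == (4, 4) and y_batch[1].shape == (4,)

Validation still runs check_pair_numpy_param on each tuple element against inputs, so a length mismatch in any component is caught before batching starts.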


+120 -2   tests/ut/python/adv_robustness/attacks/test_gradient_method.py

@@ -18,9 +18,9 @@ import numpy as np
 import pytest

 import mindspore.nn as nn
-from mindspore.nn import Cell
+from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
 import mindspore.context as context
-from mindspore.nn import SoftmaxCrossEntropyWithLogits
+from mindspore.ops.composite import GradOperation

 from mindarmour.adv_robustness.attacks import FastGradientMethod
 from mindarmour.adv_robustness.attacks import FastGradientSignMethod
@@ -57,6 +57,52 @@ class Net(Cell):
         return out


+class Net2(Cell):
+    """
+    Construct the network of the target model: a network with multiple inputs.
+
+    Examples:
+        >>> net = Net2()
+    """
+
+    def __init__(self):
+        super(Net2, self).__init__()
+        self._relu = nn.ReLU()
+
+    def construct(self, inputs1, inputs2):
+        out1 = self._relu(inputs1)
+        out2 = self._relu(inputs2)
+        return out1 + out2
+
+
+class WithLossCell(Cell):
+    """Wrap the network with a loss function."""
+    def __init__(self, backbone, loss_fn):
+        super(WithLossCell, self).__init__(auto_prefix=False)
+        self._backbone = backbone
+        self._loss_fn = loss_fn
+
+    def construct(self, inputs1, inputs2, labels):
+        out = self._backbone(inputs1, inputs2)
+        return self._loss_fn(out, labels)
+
+
+class GradWrapWithLoss(Cell):
+    """
+    Construct a network that computes the gradient of the loss function
+    with respect to the inputs.
+    """
+
+    def __init__(self, network):
+        super(GradWrapWithLoss, self).__init__()
+        self._grad_all = GradOperation(get_all=True, sens_param=False)
+        self._network = network
+
+    def construct(self, inputs1, inputs2, labels):
+        gout = self._grad_all(self._network)(inputs1, inputs2, labels)
+        return gout[0]
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
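
A note on the new GradWrapWithLoss helper: GradOperation(get_all=True, sens_param=False) returns one gradient per network input, and the wrapper keeps only gout[0], the gradient with respect to inputs1, since that is the tensor the attack perturbs. A minimal sketch of the same pattern on a hypothetical toy cell:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.ops.composite import GradOperation

    class Toy(nn.Cell):
        """Hypothetical two-input cell, used only to illustrate GradOperation."""
        def __init__(self):
            super(Toy, self).__init__()
            self._sum = P.ReduceSum()

        def construct(self, x, y):
            return self._sum(x * y)

    grad_all = GradOperation(get_all=True, sens_param=False)
    x = Tensor(np.ones([2, 3]).astype(np.float32))
    y = Tensor(np.full([2, 3], 2.0).astype(np.float32))
    grads = grad_all(Toy())(x, y)  # tuple with one gradient per input: (d/dx, d/dy)
    assert len(grads) == 2
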
@@ -234,6 +280,78 @@ def test_random_least_likely_class_method():
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_card
 @pytest.mark.component_mindarmour
+def test_fast_gradient_method_multi_inputs():
+    """
+    Fast gradient method unit test for a model with multiple inputs.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    input_np = np.asarray([[0.1, 0.2, 0.7]]).astype(np.float32)
+    anno_np = np.asarray([[0.4, 0.8, 0.5]]).astype(np.float32)
+    label = np.asarray([2], np.int32)
+    label = np.eye(3)[label].astype(np.float32)
+
+    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
+    with_loss_cell = WithLossCell(Net2(), loss_fn)
+    grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
+    attack = FastGradientMethod(grad_with_loss_net)
+    ms_adv_x = attack.generate(input_np, (anno_np, label))
+
+    assert np.any(ms_adv_x != input_np), 'Fast gradient method: generated value' \
+                                         ' must not be equal to original value.'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_batch_generate():
+    """
+    Batch generate unit test for the fast gradient method.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    input_np = np.random.random([10, 3]).astype(np.float32)
+    label = np.random.randint(0, 3, [10])
+    label = np.eye(3)[label].astype(np.float32)
+
+    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
+    attack = FastGradientMethod(Net(), loss_fn=loss_fn)
+    ms_adv_x = attack.batch_generate(input_np, label, 4)
+
+    assert np.any(ms_adv_x != input_np), 'Fast gradient method: generated value' \
+                                         ' must not be equal to original value.'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_batch_generate_multi_inputs():
+    """
+    Batch generate unit test for a model with multiple inputs.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    input_np = np.random.random([10, 3]).astype(np.float32)
+    anno_np = np.random.random([10, 3]).astype(np.float32)
+    label = np.random.randint(0, 3, [10])
+    label = np.eye(3)[label].astype(np.float32)
+
+    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
+    with_loss_cell = WithLossCell(Net2(), loss_fn)
+    grad_with_loss_net = GradWrapWithLoss(with_loss_cell)
+    attack = FastGradientMethod(grad_with_loss_net)
+    ms_adv_x = attack.batch_generate(input_np, (anno_np, label), batch_size=4)
+
+    assert np.any(ms_adv_x != input_np), 'Fast gradient method: generated value' \
+                                         ' must not be equal to original value.'
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_card
 @pytest.mark.component_mindarmour
 def test_assert_error():
     """
     Random least likely class method unit test.


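Taken together, the test additions show the intended calling pattern for multi-input models: wrap the backbone and loss in WithLossCell, wrap that in GradWrapWithLoss, hand the result to the attack, and pass the extra input together with the one-hot labels as a tuple. A condensed sketch of that flow, reusing the Net2, WithLossCell, and GradWrapWithLoss classes from the test file above (shapes are illustrative):

    import numpy as np
    import mindspore.context as context
    from mindspore.nn import SoftmaxCrossEntropyWithLogits
    from mindarmour.adv_robustness.attacks import FastGradientMethod

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    # Differentiable loss network around the two-input model.
    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=False)
    grad_net = GradWrapWithLoss(WithLossCell(Net2(), loss_fn))
    attack = FastGradientMethod(grad_net)

    inputs = np.random.random([10, 3]).astype(np.float32)
    extra = np.random.random([10, 3]).astype(np.float32)  # second model input
    onehot = np.eye(3)[np.random.randint(0, 3, [10])].astype(np.float32)

    # labels is a tuple, so batch_generate slices each element with the same
    # batch indices; the remainder batch (10 % 4 = 2 samples) takes the same path.
    adv = attack.batch_generate(inputs, (extra, onehot), batch_size=4)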