Browse Source

Extend Deepfool to object detection models

tags/v1.2.1
liuluobin 4 years ago
parent
commit
8fded5f4ff
7 changed files with 346 additions and 175 deletions
  1. +2
    -13
      mindarmour/adv_robustness/attacks/attack.py
  2. +217
    -66
      mindarmour/adv_robustness/attacks/deep_fool.py
  3. +7
    -39
      mindarmour/adv_robustness/attacks/gradient_method.py
  4. +7
    -51
      mindarmour/adv_robustness/attacks/iterative_gradient_method.py
  5. +19
    -0
      mindarmour/utils/_check_param.py
  6. +56
    -6
      mindarmour/utils/util.py
  7. +38
    -0
      tests/ut/python/adv_robustness/attacks/test_deep_fool.py

+ 2
- 13
mindarmour/adv_robustness/attacks/attack.py View File

@@ -18,7 +18,7 @@ from abc import abstractmethod

import numpy as np

from mindarmour.utils._check_param import check_pair_numpy_param, \
from mindarmour.utils._check_param import check_inputs_labels, \
check_int_positive, check_equal_shape, check_numpy_param, check_model
from mindarmour.utils.util import calculate_iou
from mindarmour.utils.logger import LogUtil
@@ -55,18 +55,7 @@ class Attack:
>>> labels = np.array([3, 0])
>>> advs = attack.batch_generate(inputs, labels, batch_size=2)
"""
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs
arr_y = labels
len_x = inputs_image.shape[0]


+ 217
- 66
mindarmour/adv_robustness/attacks/deep_fool.py View File

@@ -20,16 +20,76 @@ from mindspore import Tensor
from mindspore.nn import Cell

from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import GradWrap, jacobian_matrix
from mindarmour.utils.util import GradWrap, jacobian_matrix, \
jacobian_matrix_for_detection, calculate_iou, to_tensor_tuple
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
check_value_positive, check_int_positive, check_norm_level, \
check_param_multi_types, check_param_type
check_param_multi_types, check_param_type, check_value_non_negative
from .attack import Attack

LOGGER = LogUtil.get_instance()
TAG = 'DeepFool'


class _GetLogits(Cell):
def __init__(self, network):
super(_GetLogits, self).__init__()
self._network = network

def construct(self, *inputs):
_, pre_logits = self._network(*inputs)
return pre_logits


def _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network):
"""
Evaluate the detection result of inputs, specially for object detection models.

Args:
inputs (numpy.ndarray): Input samples.
gt_boxes (numpy.ndarray): Ground-truth boxes of inputs.
gt_labels (numpy.ndarray): Ground-truth labels of inputs.
model (BlackModel): Target model.

Returns:
- numpy.ndarray, detection scores of inputs.

- numpy.ndarray, the number of objects that are correctly detected.
"""
network = check_param_type('network', network, Cell)
inputs_tensor = to_tensor_tuple(inputs)
box_and_confi, pred_logits = network(*inputs_tensor)
box_and_confi, pred_logits = box_and_confi.asnumpy(), pred_logits.asnumpy()
pred_labels = np.argmax(pred_logits, axis=2)
det_scores = []
correct_labels_num = []
gt_boxes_num = gt_boxes.shape[0]
iou_thres = 0.5
for idx, (boxes, labels) in enumerate(zip(box_and_confi, pred_labels)):
score = 0
box_num = boxes.shape[0]
correct_label_flag = np.zeros(gt_labels.shape)
gt_boxes_idx = gt_boxes[idx]
gt_labels_idx = gt_labels[idx]
for i in range(box_num):
pred_box = boxes[i]
max_iou_confi = 0
for j in range(gt_boxes_num):
iou = calculate_iou(pred_box[:4], gt_boxes_idx[j][:4])
if labels[i] == gt_labels_idx[j] and iou > iou_thres:
max_iou_confi = max(max_iou_confi, pred_box[-1] + iou)
correct_label_flag[j] = 1
score += max_iou_confi
det_scores.append(score)
correct_labels_num.append(np.sum(correct_label_flag))
return np.array(det_scores), np.array(correct_labels_num)


def _is_success(inputs, gt_boxes, gt_labels, network, gt_object_nums, reserve_ratio):
_, correct_nums_adv = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network)
return np.all(correct_nums_adv <= (gt_object_nums*reserve_ratio).astype(np.int32))


class DeepFool(Attack):
"""
DeepFool is an untargeted & iterative attack achieved by moving the benign
@@ -56,8 +116,8 @@ class DeepFool(Attack):
>>> attack = DeepFool(network)
"""

def __init__(self, network, num_classes, max_iters=50, overshoot=0.02,
norm_level=2, bounds=None, sparse=True):
def __init__(self, network, num_classes, model_type='classification',
reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True):
super(DeepFool, self).__init__()
self._network = check_model('network', network, Cell)
self._network.set_grad(True)
@@ -66,18 +126,32 @@ class DeepFool(Attack):
self._norm_level = check_norm_level(norm_level)
self._num_classes = check_int_positive('num_classes', num_classes)
self._net_grad = GradWrap(self._network)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
self._bounds = bounds
if self._bounds is not None:
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._sparse = check_param_type('sparse', sparse, bool)
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._model_type = check_param_type('model_type', model_type, str)
if self._model_type not in ('classification', 'detection'):
msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
if self._reserve_ratio > 1:
msg = 'reserve_ratio should be less than 1.0, but got {}.'.format(self._reserve_ratio)
LOGGER.error(TAG, msg)
raise ValueError(TAG, msg)

def generate(self, inputs, labels):
"""
Generate adversarial examples based on input samples and original labels.

Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Original labels.
inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (inputs1, input2, ...) \
or only one array if model_type='detection'
labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be \
(gt_boxes, gt_labels) if model_type='detection'.

Returns:
numpy.ndarray, adversarial examples.
@@ -88,67 +162,144 @@ class DeepFool(Attack):
Examples:
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue
for k in range(self._num_classes):
if k == label:
if self._model_type == 'detection':
images, auxiliary_inputs = inputs[0], inputs[1:]
gt_boxes, gt_labels = labels
_, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network)
if not self._sparse:
gt_labels = np.argmax(gt_labels, axis=2)
origin_labels = np.zeros(gt_labels.shape[0])
for i in range(gt_labels.shape[0]):
origin_labels[i] = np.argmax(np.bincount(gt_labels[i]))
images_dtype = images.dtype
iteration = 0
num_boxes = gt_labels.shape[1]
merge_net = _GetLogits(self._network)
detection_net_grad = GradWrap(merge_net)
weight = np.squeeze(np.zeros(images.shape[1:]))
r_tot = np.zeros(images.shape)
x_origin = images
while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \
self._reserve_ratio) and iteration < self._max_iters:
preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy()
grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs,
num_boxes, self._num_classes)
for idx in range(images.shape[0]):
diff_w = np.inf
label = int(origin_labels[idx])
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[idx], axis=0),)
gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0)
gt_labels_i = np.expand_dims(gt_labels[idx], axis=0)
inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i
if _is_success(inputs_i, gt_boxes_i, gt_labels_i,
self._network, gt_object_nums[idx], self._reserve_ratio):
continue
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...]))
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else:
msg = 'ord {} is not available in normalization,' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
else:
images = x_origin + (1 + self._overshoot)*r_tot
iteration += 1
images = images.astype(images_dtype)
del preds_logits, grads
return images

if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
msg = 'ord {} is not available in normalization.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
else:
msg = 'ord {} is not available in normalization.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
else:
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs
return None

+ 7
- 39
mindarmour/adv_robustness/attacks/gradient_method.py View File

@@ -18,12 +18,11 @@ from abc import abstractmethod

import numpy as np

from mindspore import Tensor
from mindspore.nn import Cell

from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
from mindarmour.utils._check_param import check_model, check_inputs_labels, \
normalize_value, check_value_positive, check_param_multi_types, \
check_norm_level, check_param_type
from .attack import Attack
@@ -91,18 +90,7 @@ class GradientMethod(Attack):
Returns:
numpy.ndarray, generated adversarial examples.
"""
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
self._dtype = inputs_image.dtype
gradient = self._gradient(inputs, labels)
# use random method or not
@@ -196,18 +184,8 @@ class FastGradientMethod(GradientMethod):
Returns:
numpy.ndarray, gradient of inputs.
"""
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
@@ -315,18 +293,8 @@ class FastGradientSignMethod(GradientMethod):
Returns:
numpy.ndarray, gradient of inputs.
"""
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]


+ 7
- 51
mindarmour/adv_robustness/attacks/iterative_gradient_method.py View File

@@ -18,11 +18,10 @@ import numpy as np
from PIL import Image, ImageOps

from mindspore.nn import Cell
from mindspore import Tensor

from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss
from mindarmour.utils._check_param import check_pair_numpy_param, \
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils._check_param import check_inputs_labels, \
normalize_value, check_model, check_value_positive, check_int_positive, \
check_param_type, check_norm_level, check_param_multi_types
from .attack import Attack
@@ -223,18 +222,7 @@ class BasicIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
>>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
"""
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image
if self._bounds is not None:
clip_min, clip_max = self._bounds
@@ -322,18 +310,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
>>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
"""
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image
momentum = 0
if self._bounds is not None:
@@ -392,18 +369,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
>>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
"""
# get grad of loss over x
if isinstance(inputs, tuple):
inputs_tensor = tuple()
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs),)
if isinstance(labels, tuple):
labels_tensor = tuple()
for item in labels:
labels_tensor += (Tensor(item),)
else:
labels_tensor = (Tensor(labels),)
inputs_tensor = to_tensor_tuple(inputs)
labels_tensor = to_tensor_tuple(labels)
out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
if isinstance(out_grad, tuple):
out_grad = out_grad[0]
@@ -473,18 +440,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
>>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
>>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
"""
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
arr_x = inputs_image
if self._bounds is not None:
clip_min, clip_max = self._bounds


+ 19
- 0
mindarmour/utils/_check_param.py View File

@@ -327,3 +327,22 @@ def check_detection_inputs(inputs, labels):
LOGGER.error(TAG, msg)
raise ValueError(msg)
return images, auxiliary_inputs, gt_boxes, gt_labels


def check_inputs_labels(inputs, labels):
"""check inputs and labels is valid for white box method."""
_ = check_param_multi_types('inputs', inputs, (tuple, np.ndarray))
_ = check_param_multi_types('labels', labels, (tuple, np.ndarray))
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
return inputs_image, inputs, labels

+ 56
- 6
mindarmour/utils/util.py View File

@@ -17,7 +17,7 @@ from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.utils._check_param import check_numpy_param
from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types

from .logger import LogUtil

@@ -54,6 +54,44 @@ def jacobian_matrix(grad_wrap_net, inputs, num_classes):
return np.asarray(grads_matrix)


def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes):
"""
Calculate the Jacobian matrix for inputs, specifically for object detection model.

Args:
grad_wrap_net (Cell): A network wrapped by GradWrap.
inputs (numpy.ndarray): Input samples.
num_boxes (int): Number of boxes infered by each image.
num_classes (int): Number of labels of model output.

Returns:
numpy.ndarray, the Jacobian matrix of inputs. (labels, batch_size, ...)

Raises:
ValueError: If grad_wrap_net is not a instance of class `GradWrap`.
"""
if not isinstance(grad_wrap_net, GradWrap):
msg = 'grad_wrap_net be and instance of class `GradWrap`.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
grad_wrap_net.set_train()
grads_matrix = []
inputs_tensor = tuple()
if isinstance(inputs, tuple):
for item in inputs:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor += (Tensor(inputs),)
for idx in range(num_classes):
batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32)
sens[:, :, idx] = 1.0

grads = grad_wrap_net(*(inputs_tensor), Tensor(sens))
grads_matrix.append(grads.asnumpy())
return np.asarray(grads_matrix)


class WithLossCell(Cell):
"""
Wrap the network with loss function.
@@ -152,19 +190,19 @@ class GradWrap(Cell):
self.grad = GradOperation(get_all=False, sens_param=True)
self.network = network

def construct(self, inputs, weight):
def construct(self, *data):
"""
Compute jacobian matrix.

Args:
inputs (Tensor): Inputs of network.
weight (Tensor): Weight of each gradient, `weight` has the same
shape with labels.
data (Tensor): Data consists of inputs and weight. \
- inputs: Inputs of network. \
- weight: Weight of each gradient, 'weight' has the same shape with labels.

Returns:
Tensor, Jacobian matrix.
"""
gout = self.grad(self.network)(inputs, weight)
gout = self.grad(self.network)(*data)
return gout


@@ -199,3 +237,15 @@ def calculate_iou(box_i, box_j):
return 0
inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line)
return inner_area / (s_i + s_j - inner_area)


def to_tensor_tuple(inputs_ori):
"""Transfer inputs data into tensor type."""
inputs_ori = check_param_multi_types('inputs_ori', inputs_ori, [np.ndarray, tuple])
if isinstance(inputs_ori, tuple):
inputs_tensor = tuple()
for item in inputs_ori:
inputs_tensor += (Tensor(item),)
else:
inputs_tensor = (Tensor(inputs_ori),)
return inputs_tensor

+ 38
- 0
tests/ut/python/adv_robustness/attacks/test_deep_fool.py View File

@@ -54,6 +54,23 @@ class Net(Cell):
return out


class Net2(Cell):
"""
Construct the network of target model, specifically for detection model test case.

Examples:
>>> net = Net2()
"""
def __init__(self):
super(Net2, self).__init__()
self._softmax = P.Softmax()

def construct(self, inputs1, inputs2):
out1 = self._softmax(inputs2)
out2 = self._softmax(inputs1)
return out1, out2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -84,6 +101,27 @@ def test_deepfool_attack():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack_detection():
"""
Deepfool-Attack test
"""
net = Net2()
inputs1_np = np.random.random((2, 10, 10)).astype(np.float32)
inputs2_np = np.random.random((2, 10, 5)).astype(np.float32)
gt_boxes = inputs1_np[:, :, 0: 5]
gt_labels = np.argmax(inputs1_np, axis=2)
num_classes = 10

attack = DeepFool(net, num_classes, model_type='detection', reserve_ratio=0.3,
bounds=(0.0, 1.0))
_ = attack.generate((inputs1_np, inputs2_np), (gt_boxes, gt_labels))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack_inf():
"""
Deepfool-Attack test


Loading…
Cancel
Save