@@ -18,7 +18,7 @@ from abc import abstractmethod
import numpy as np

from mindarmour.utils._check_param import check_inputs_labels, \
    check_int_positive, check_equal_shape, check_numpy_param, check_model
from mindarmour.utils.util import calculate_iou
from mindarmour.utils.logger import LogUtil
@@ -55,18 +55,7 @@ class Attack:
        >>> labels = np.array([3, 0])
        >>> advs = attack.batch_generate(inputs, labels, batch_size=2)
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        arr_x = inputs
        arr_y = labels
        len_x = inputs_image.shape[0]
@@ -20,16 +20,76 @@ from mindspore import Tensor
from mindspore.nn import Cell

from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import GradWrap, jacobian_matrix, \
    jacobian_matrix_for_detection, calculate_iou, to_tensor_tuple
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
    check_value_positive, check_int_positive, check_norm_level, \
    check_param_multi_types, check_param_type, check_value_non_negative
from .attack import Attack

LOGGER = LogUtil.get_instance()
TAG = 'DeepFool'


class _GetLogits(Cell):
    """Wrap a detection network so that only its pre-softmax logits are returned."""
    def __init__(self, network):
        super(_GetLogits, self).__init__()
        self._network = network

    def construct(self, *inputs):
        # The wrapped detection network returns (box_and_confidence, pre_logits);
        # only the logits are needed for the gradient computation.
        _, pre_logits = self._network(*inputs)
        return pre_logits

def _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network):
    """
    Evaluate the detection result of inputs, specifically for object detection models.

    Args:
        inputs (Union[numpy.ndarray, tuple]): Input samples.
        gt_boxes (numpy.ndarray): Ground-truth boxes of inputs.
        gt_labels (numpy.ndarray): Ground-truth labels of inputs.
        network (Cell): Target network.

    Returns:
        - numpy.ndarray, detection scores of inputs.

        - numpy.ndarray, the number of objects that are correctly detected.
    """
    network = check_param_type('network', network, Cell)
    inputs_tensor = to_tensor_tuple(inputs)
    box_and_confi, pred_logits = network(*inputs_tensor)
    box_and_confi, pred_logits = box_and_confi.asnumpy(), pred_logits.asnumpy()
    pred_labels = np.argmax(pred_logits, axis=2)
    det_scores = []
    correct_labels_num = []
    gt_boxes_num = gt_boxes.shape[0]
    iou_thres = 0.5
    for idx, (boxes, labels) in enumerate(zip(box_and_confi, pred_labels)):
        score = 0
        box_num = boxes.shape[0]
        correct_label_flag = np.zeros(gt_labels.shape)
        gt_boxes_idx = gt_boxes[idx]
        gt_labels_idx = gt_labels[idx]
        for i in range(box_num):
            pred_box = boxes[i]
            max_iou_confi = 0
            for j in range(gt_boxes_num):
                iou = calculate_iou(pred_box[:4], gt_boxes_idx[j][:4])
                if labels[i] == gt_labels_idx[j] and iou > iou_thres:
                    max_iou_confi = max(max_iou_confi, pred_box[-1] + iou)
                    correct_label_flag[j] = 1
            score += max_iou_confi
        det_scores.append(score)
        correct_labels_num.append(np.sum(correct_label_flag))
    return np.array(det_scores), np.array(correct_labels_num)
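
# A minimal worked example of the scoring rule above, assuming boxes in
# (x1, y1, x2, y2) corner format with a trailing confidence on predictions
# (illustrative numbers, not from the MindArmour test suite):
#
#     >>> pred_box = np.array([0., 0., 2., 2., 0.9])
#     >>> gt_box = np.array([0., 0., 2., 2.])
#     >>> iou = calculate_iou(pred_box[:4], gt_box)  # 1.0 for identical boxes
#     >>> pred_box[-1] + iou                         # score contribution: 1.9
#
# Each predicted box keeps only its best-scoring ground-truth match above the
# IoU threshold, and the per-image detection score sums over predicted boxes.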

def _is_success(inputs, gt_boxes, gt_labels, network, gt_object_nums, reserve_ratio):
    """Check whether the number of correctly detected objects has dropped to the reserved quota."""
    _, correct_nums_adv = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network)
    return np.all(correct_nums_adv <= (gt_object_nums*reserve_ratio).astype(np.int32))
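
# A quick sanity check of the success criterion, assuming 10 ground-truth
# objects per image and the default reserve_ratio of 0.3:
#
#     >>> gt_object_nums = np.array([10])
#     >>> (gt_object_nums*0.3).astype(np.int32)
#     array([3], dtype=int32)
#
# i.e. the attack is judged successful once at most 3 objects per image are
# still correctly detected.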

class DeepFool(Attack):
    """
    DeepFool is an untargeted and iterative attack achieved by moving the benign
@@ -56,8 +116,8 @@ class DeepFool(Attack):
        >>> attack = DeepFool(network)
    """
    def __init__(self, network, num_classes, model_type='classification',
                 reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2,
                 bounds=None, sparse=True):
        super(DeepFool, self).__init__()
        self._network = check_model('network', network, Cell)
        self._network.set_grad(True)
@@ -66,18 +126,32 @@ class DeepFool(Attack):
        self._norm_level = check_norm_level(norm_level)
        self._num_classes = check_int_positive('num_classes', num_classes)
        self._net_grad = GradWrap(self._network)
        self._bounds = bounds
        if self._bounds is not None:
            self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
            for b in self._bounds:
                _ = check_param_multi_types('bound', b, [int, float])
        self._sparse = check_param_type('sparse', sparse, bool)
        self._model_type = check_param_type('model_type', model_type, str)
        if self._model_type not in ('classification', 'detection'):
            msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
        if self._reserve_ratio > 1:
            msg = 'reserve_ratio should not be greater than 1.0, but got {}.'.format(self._reserve_ratio)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
    def generate(self, inputs, labels):
        """
        Generate adversarial examples based on input samples and original labels.

        Args:
            inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs should be
                (input1, input2, ...) if model_type='detection', or a single numpy.ndarray
                if model_type='classification'.
            labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be
                (gt_boxes, gt_labels) if model_type='detection', or a numpy.ndarray of class
                labels if model_type='classification'.

        Returns:
            numpy.ndarray, adversarial examples.
@@ -88,67 +162,144 @@ class DeepFool(Attack):

        Examples:
            >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
        """
        if self._model_type == 'detection':
            images, auxiliary_inputs = inputs[0], inputs[1:]
            gt_boxes, gt_labels = labels
            _, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network)
            if not self._sparse:
                gt_labels = np.argmax(gt_labels, axis=2)
            origin_labels = np.zeros(gt_labels.shape[0])
            for i in range(gt_labels.shape[0]):
                origin_labels[i] = np.argmax(np.bincount(gt_labels[i]))
            images_dtype = images.dtype
            iteration = 0
            num_boxes = gt_labels.shape[1]
            merge_net = _GetLogits(self._network)
            detection_net_grad = GradWrap(merge_net)
            weight = np.squeeze(np.zeros(images.shape[1:]))
            r_tot = np.zeros(images.shape)
            x_origin = images
            while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, \
                                  gt_object_nums, self._reserve_ratio) and iteration < self._max_iters:
                preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy()
                grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs,
                                                      num_boxes, self._num_classes)
                for idx in range(images.shape[0]):
                    diff_w = np.inf
                    label = int(origin_labels[idx])
                    auxiliary_input_i = tuple()
                    for item in auxiliary_inputs:
                        auxiliary_input_i += (np.expand_dims(item[idx], axis=0),)
                    gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0)
                    gt_labels_i = np.expand_dims(gt_labels[idx], axis=0)
                    inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i
                    if _is_success(inputs_i, gt_boxes_i, gt_labels_i,
                                   self._network, gt_object_nums[idx], self._reserve_ratio):
                        continue
                    for k in range(self._num_classes):
                        if k == label:
                            continue
                        w_k = grads[k, idx, ...] - grads[label, idx, ...]
                        f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...]))
                        if self._norm_level == 2 or self._norm_level == '2':
                            diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
                        elif self._norm_level == np.inf \
                                or self._norm_level == 'inf':
                            diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
                        else:
                            msg = 'ord {} is not available.' \
                                .format(str(self._norm_level))
                            LOGGER.error(TAG, msg)
                            raise NotImplementedError(msg)
                        if diff_w_k < diff_w:
                            diff_w = diff_w_k
                            weight = w_k
                    if self._norm_level == 2 or self._norm_level == '2':
                        r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
                    elif self._norm_level == np.inf or self._norm_level == 'inf':
                        r_i = diff_w*np.sign(weight) \
                              / (np.linalg.norm(weight, ord=1) + 1e-8)
                    else:
                        msg = 'ord {} is not available in normalization.' \
                            .format(str(self._norm_level))
                        LOGGER.error(TAG, msg)
                        raise NotImplementedError(msg)
                    r_tot[idx, ...] = r_tot[idx, ...] + r_i
                if self._bounds is not None:
                    clip_min, clip_max = self._bounds
                    images = x_origin + (1 + self._overshoot)*r_tot*(clip_max - clip_min)
                    images = np.clip(images, clip_min, clip_max)
                else:
                    images = x_origin + (1 + self._overshoot)*r_tot
                iteration += 1
            images = images.astype(images_dtype)
            del preds_logits, grads
            return images
        if self._model_type == 'classification':
            inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                    'labels', labels)
            if not self._sparse:
                labels = np.argmax(labels, axis=1)
            inputs_dtype = inputs.dtype
            iteration = 0
            origin_labels = labels
            cur_labels = origin_labels.copy()
            weight = np.squeeze(np.zeros(inputs.shape[1:]))
            r_tot = np.zeros(inputs.shape)
            x_origin = inputs
            while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
                preds = self._network(Tensor(inputs)).asnumpy()
                grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
                for idx in range(inputs.shape[0]):
                    diff_w = np.inf
                    label = origin_labels[idx]
                    if cur_labels[idx] != label:
                        continue
                    for k in range(self._num_classes):
                        if k == label:
                            continue
                        w_k = grads[k, idx, ...] - grads[label, idx, ...]
                        f_k = preds[idx, k] - preds[idx, label]
                        if self._norm_level == 2 or self._norm_level == '2':
                            diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
                        elif self._norm_level == np.inf \
                                or self._norm_level == 'inf':
                            diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
                        else:
                            msg = 'ord {} is not available.' \
                                .format(str(self._norm_level))
                            LOGGER.error(TAG, msg)
                            raise NotImplementedError(msg)
                        if diff_w_k < diff_w:
                            diff_w = diff_w_k
                            weight = w_k
                    if self._norm_level == 2 or self._norm_level == '2':
                        r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
                    elif self._norm_level == np.inf or self._norm_level == 'inf':
                        r_i = diff_w*np.sign(weight) \
                              / (np.linalg.norm(weight, ord=1) + 1e-8)
                    else:
                        msg = 'ord {} is not available in normalization.' \
                            .format(str(self._norm_level))
                        LOGGER.error(TAG, msg)
                        raise NotImplementedError(msg)
                    r_tot[idx, ...] = r_tot[idx, ...] + r_i
                if self._bounds is not None:
                    clip_min, clip_max = self._bounds
                    inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
                                                                     - clip_min)
                    inputs = np.clip(inputs, clip_min, clip_max)
                else:
                    inputs = x_origin + (1 + self._overshoot)*r_tot
                cur_labels = np.argmax(
                    self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
                    axis=1)
                iteration += 1
            inputs = inputs.astype(inputs_dtype)
            del preds, grads
            return inputs
        return None
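
# A worked instance of the per-class update above, with illustrative numbers
# (not from the DeepFool paper): suppose for sample idx the closest class k has
# f_k = preds[idx, k] - preds[idx, label] = -0.5 and ||w_k||_2 = 2.0. Then for
# norm_level=2:
#
#     diff_w = |f_k| / ||w_k||_2 = 0.25
#     r_i = diff_w * w_k / ||w_k||_2      # step of length 0.25 along w_k
#
# and the (1 + overshoot) factor pushes the sample slightly past the estimated
# decision boundary so the predicted label actually flips.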
@@ -18,12 +18,11 @@ from abc import abstractmethod
import numpy as np

from mindspore import Tensor
from mindspore.nn import Cell

from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_model, check_inputs_labels, \
    normalize_value, check_value_positive, check_param_multi_types, \
    check_norm_level, check_param_type
from .attack import Attack
@@ -91,18 +90,7 @@ class GradientMethod(Attack):

        Returns:
            numpy.ndarray, generated adversarial examples.
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        self._dtype = inputs_image.dtype
        gradient = self._gradient(inputs, labels)
        # use random method or not
@@ -196,18 +184,8 @@ class FastGradientMethod(GradientMethod):

        Returns:
            numpy.ndarray, gradient of inputs.
        """
        inputs_tensor = to_tensor_tuple(inputs)
        labels_tensor = to_tensor_tuple(labels)
        out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
        if isinstance(out_grad, tuple):
            out_grad = out_grad[0]
@@ -315,18 +293,8 @@ class FastGradientSignMethod(GradientMethod):

        Returns:
            numpy.ndarray, gradient of inputs.
        """
        inputs_tensor = to_tensor_tuple(inputs)
        labels_tensor = to_tensor_tuple(labels)
        out_grad = self._grad_all(*inputs_tensor, *labels_tensor)
        if isinstance(out_grad, tuple):
            out_grad = out_grad[0]
@@ -18,11 +18,10 @@ import numpy as np
from PIL import Image, ImageOps

from mindspore.nn import Cell
from mindspore import Tensor

from mindarmour.utils.logger import LogUtil
from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple
from mindarmour.utils._check_param import check_inputs_labels, \
    normalize_value, check_model, check_value_positive, check_int_positive, \
    check_param_type, check_norm_level, check_param_multi_types
from .attack import Attack
@@ -223,18 +222,7 @@ class BasicIterativeMethod(IterativeGradientMethod):
            >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]])
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        arr_x = inputs_image
        if self._bounds is not None:
            clip_min, clip_max = self._bounds
@@ -322,18 +310,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
            >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
            >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        arr_x = inputs_image
        momentum = 0
        if self._bounds is not None:
@@ -392,18 +369,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
            >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
        """
        # get grad of loss over x
        inputs_tensor = to_tensor_tuple(inputs)
        labels_tensor = to_tensor_tuple(labels)
        out_grad = self._loss_grad(*inputs_tensor, *labels_tensor)
        if isinstance(out_grad, tuple):
            out_grad = out_grad[0]
@@ -473,18 +440,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
            >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
        """
        inputs_image, inputs, labels = check_inputs_labels(inputs, labels)
        arr_x = inputs_image
        if self._bounds is not None:
            clip_min, clip_max = self._bounds
@@ -327,3 +327,22 @@ def check_detection_inputs(inputs, labels):
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    return images, auxiliary_inputs, gt_boxes, gt_labels


def check_inputs_labels(inputs, labels):
    """Check whether inputs and labels are valid for white-box attack methods."""
    _ = check_param_multi_types('inputs', inputs, (tuple, np.ndarray))
    _ = check_param_multi_types('labels', labels, (tuple, np.ndarray))
    inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
    if isinstance(inputs, tuple):
        for i, inputs_item in enumerate(inputs):
            _ = check_pair_numpy_param('inputs_image', inputs_image,
                                       'inputs[{}]'.format(i), inputs_item)
    if isinstance(labels, tuple):
        for i, labels_item in enumerate(labels):
            _ = check_pair_numpy_param('inputs', inputs_image,
                                       'labels[{}]'.format(i), labels_item)
    else:
        _ = check_pair_numpy_param('inputs', inputs_image,
                                   'labels', labels)
    return inputs_image, inputs, labels
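
# A small doctest-style sketch of the contract above (shapes are illustrative):
#
#     >>> images = np.zeros((2, 3, 32, 32))
#     >>> labels = np.zeros((2,))
#     >>> img, ins, labs = check_inputs_labels((images,), labels)
#     >>> img is images
#     True
#
# Every element of a tuple of inputs or labels presumably has to agree with the
# first input array on the batch dimension, enforced by check_pair_numpy_param.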
@@ -17,7 +17,7 @@ from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation

from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types
from .logger import LogUtil
@@ -54,6 +54,44 @@ def jacobian_matrix(grad_wrap_net, inputs, num_classes):
    return np.asarray(grads_matrix)

def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes):
    """
    Calculate the Jacobian matrix for inputs, specifically for object detection models.

    Args:
        grad_wrap_net (Cell): A network wrapped by GradWrap.
        inputs (Union[numpy.ndarray, tuple]): Input samples.
        num_boxes (int): Number of boxes inferred for each image.
        num_classes (int): Number of labels of model output.

    Returns:
        numpy.ndarray, the Jacobian matrix of inputs, with shape (num_classes, batch_size, ...).

    Raises:
        ValueError: If grad_wrap_net is not an instance of class `GradWrap`.
    """
    if not isinstance(grad_wrap_net, GradWrap):
        msg = 'grad_wrap_net should be an instance of class `GradWrap`.'
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    grad_wrap_net.set_train()
    grads_matrix = []
    inputs_tensor = tuple()
    if isinstance(inputs, tuple):
        for item in inputs:
            inputs_tensor += (Tensor(item),)
    else:
        inputs_tensor += (Tensor(inputs),)
    for idx in range(num_classes):
        batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
        # Seed the gradient with a one-hot sensitivity over the class dimension,
        # so each pass yields the gradient of that class's logits w.r.t. the inputs.
        sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32)
        sens[:, :, idx] = 1.0
        grads = grad_wrap_net(*(inputs_tensor), Tensor(sens))
        grads_matrix.append(grads.asnumpy())
    return np.asarray(grads_matrix)
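
# Shape sketch under illustrative assumptions (batch of 2 images, 10 boxes,
# 5 classes): each one-hot `sens` slice has shape (2, 10, 5), the wrapped
# network returns a gradient shaped like the first input, and the stacked
# result has shape (5, 2, ...), indexed as grads[class_idx, sample_idx, ...].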

class WithLossCell(Cell):
    """
    Wrap the network with loss function.
@@ -152,19 +190,19 @@ class GradWrap(Cell):
        self.grad = GradOperation(get_all=False, sens_param=True)
        self.network = network

    def construct(self, *data):
        """
        Compute jacobian matrix.

        Args:
            data (Tensor): Data consists of inputs and weight.

                - inputs: Inputs of network.

                - weight: Weight of each gradient, `weight` has the same shape as labels.

        Returns:
            Tensor, Jacobian matrix.
        """
        gout = self.grad(self.network)(*data)
        return gout
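
# A minimal usage sketch, assuming `net` maps a (batch, num_classes) input to
# logits of the same shape (illustrative only, not from the MindArmour docs):
#
#     >>> wrapped = GradWrap(net)
#     >>> sens = Tensor(np.ones((1, num_classes)).astype(np.float32))
#     >>> grad = wrapped(Tensor(x), sens)   # d(logits)/dx, weighted by sens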
@@ -199,3 +237,15 @@ def calculate_iou(box_i, box_j):
        return 0
    inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line)
    return inner_area / (s_i + s_j - inner_area)


def to_tensor_tuple(inputs_ori):
    """Transform input data into a tuple of tensors."""
    inputs_ori = check_param_multi_types('inputs_ori', inputs_ori, [np.ndarray, tuple])
    if isinstance(inputs_ori, tuple):
        inputs_tensor = tuple()
        for item in inputs_ori:
            inputs_tensor += (Tensor(item),)
    else:
        inputs_tensor = (Tensor(inputs_ori),)
    return inputs_tensor
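
# Usage sketch: both call forms return a tuple, so call sites can always splat
# the result into a Cell:
#
#     >>> to_tensor_tuple(np.ones((2, 3)))           # -> (Tensor(2, 3),)
#     >>> to_tensor_tuple((np.ones(2), np.ones(3)))  # -> (Tensor(2,), Tensor(3,))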
@@ -54,6 +54,23 @@ class Net(Cell):
        return out


class Net2(Cell):
    """
    Construct the network of target model, specifically for the detection model test case.

    Examples:
        >>> net = Net2()
    """
    def __init__(self):
        super(Net2, self).__init__()
        self._softmax = P.Softmax()

    def construct(self, inputs1, inputs2):
        out1 = self._softmax(inputs2)
        out2 = self._softmax(inputs1)
        return out1, out2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -84,6 +101,27 @@ def test_deepfool_attack():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack_detection():
    """
    DeepFool attack test for the detection model.
    """
    net = Net2()
    inputs1_np = np.random.random((2, 10, 10)).astype(np.float32)
    inputs2_np = np.random.random((2, 10, 5)).astype(np.float32)
    gt_boxes = inputs1_np[:, :, 0: 5]
    gt_labels = np.argmax(inputs1_np, axis=2)
    num_classes = 10
    attack = DeepFool(net, num_classes, model_type='detection', reserve_ratio=0.3,
                      bounds=(0.0, 1.0))
    _ = attack.generate((inputs1_np, inputs2_np), (gt_boxes, gt_labels))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack_inf():
    """
    DeepFool attack test with inf norm.