From 8fded5f4ff053c93e44934116473cb3e5269be79 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Wed, 18 Nov 2020 10:03:09 +0800 Subject: [PATCH] Extend Deepfool to object detection models --- mindarmour/adv_robustness/attacks/attack.py | 15 +- mindarmour/adv_robustness/attacks/deep_fool.py | 283 ++++++++++++++++----- .../adv_robustness/attacks/gradient_method.py | 46 +--- .../attacks/iterative_gradient_method.py | 58 +---- mindarmour/utils/_check_param.py | 19 ++ mindarmour/utils/util.py | 62 ++++- .../adv_robustness/attacks/test_deep_fool.py | 38 +++ 7 files changed, 346 insertions(+), 175 deletions(-) diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index 149941f..1c97a2a 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -18,7 +18,7 @@ from abc import abstractmethod import numpy as np -from mindarmour.utils._check_param import check_pair_numpy_param, \ +from mindarmour.utils._check_param import check_inputs_labels, \ check_int_positive, check_equal_shape, check_numpy_param, check_model from mindarmour.utils.util import calculate_iou from mindarmour.utils.logger import LogUtil @@ -55,18 +55,7 @@ class Attack: >>> labels = np.array([3, 0]) >>> advs = attack.batch_generate(inputs, labels, batch_size=2) """ - inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs - if isinstance(inputs, tuple): - for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) - if isinstance(labels, tuple): - for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'labels[{}]'.format(i), labels_item) - else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs arr_y = labels len_x = inputs_image.shape[0] diff --git a/mindarmour/adv_robustness/attacks/deep_fool.py b/mindarmour/adv_robustness/attacks/deep_fool.py index dc3a556..67eaba5 100644 --- a/mindarmour/adv_robustness/attacks/deep_fool.py +++ b/mindarmour/adv_robustness/attacks/deep_fool.py @@ -20,16 +20,76 @@ from mindspore import Tensor from mindspore.nn import Cell from mindarmour.utils.logger import LogUtil -from mindarmour.utils.util import GradWrap, jacobian_matrix +from mindarmour.utils.util import GradWrap, jacobian_matrix, \ + jacobian_matrix_for_detection, calculate_iou, to_tensor_tuple from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ check_value_positive, check_int_positive, check_norm_level, \ - check_param_multi_types, check_param_type + check_param_multi_types, check_param_type, check_value_non_negative from .attack import Attack LOGGER = LogUtil.get_instance() TAG = 'DeepFool' +class _GetLogits(Cell): + def __init__(self, network): + super(_GetLogits, self).__init__() + self._network = network + + def construct(self, *inputs): + _, pre_logits = self._network(*inputs) + return pre_logits + + +def _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network): + """ + Evaluate the detection result of inputs, specially for object detection models. + + Args: + inputs (numpy.ndarray): Input samples. + gt_boxes (numpy.ndarray): Ground-truth boxes of inputs. + gt_labels (numpy.ndarray): Ground-truth labels of inputs. + model (BlackModel): Target model. + + Returns: + - numpy.ndarray, detection scores of inputs. + + - numpy.ndarray, the number of objects that are correctly detected. + """ + network = check_param_type('network', network, Cell) + inputs_tensor = to_tensor_tuple(inputs) + box_and_confi, pred_logits = network(*inputs_tensor) + box_and_confi, pred_logits = box_and_confi.asnumpy(), pred_logits.asnumpy() + pred_labels = np.argmax(pred_logits, axis=2) + det_scores = [] + correct_labels_num = [] + gt_boxes_num = gt_boxes.shape[0] + iou_thres = 0.5 + for idx, (boxes, labels) in enumerate(zip(box_and_confi, pred_labels)): + score = 0 + box_num = boxes.shape[0] + correct_label_flag = np.zeros(gt_labels.shape) + gt_boxes_idx = gt_boxes[idx] + gt_labels_idx = gt_labels[idx] + for i in range(box_num): + pred_box = boxes[i] + max_iou_confi = 0 + for j in range(gt_boxes_num): + iou = calculate_iou(pred_box[:4], gt_boxes_idx[j][:4]) + if labels[i] == gt_labels_idx[j] and iou > iou_thres: + max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) + correct_label_flag[j] = 1 + score += max_iou_confi + det_scores.append(score) + correct_labels_num.append(np.sum(correct_label_flag)) + return np.array(det_scores), np.array(correct_labels_num) + + +def _is_success(inputs, gt_boxes, gt_labels, network, gt_object_nums, reserve_ratio): + _, correct_nums_adv = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, network) + return np.all(correct_nums_adv <= (gt_object_nums*reserve_ratio).astype(np.int32)) + + class DeepFool(Attack): """ DeepFool is an untargeted & iterative attack achieved by moving the benign @@ -56,8 +116,8 @@ class DeepFool(Attack): >>> attack = DeepFool(network) """ - def __init__(self, network, num_classes, max_iters=50, overshoot=0.02, - norm_level=2, bounds=None, sparse=True): + def __init__(self, network, num_classes, model_type='classification', + reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True): super(DeepFool, self).__init__() self._network = check_model('network', network, Cell) self._network.set_grad(True) @@ -66,18 +126,32 @@ class DeepFool(Attack): self._norm_level = check_norm_level(norm_level) self._num_classes = check_int_positive('num_classes', num_classes) self._net_grad = GradWrap(self._network) - self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) + self._bounds = bounds + if self._bounds is not None: + self._bounds = check_param_multi_types('bounds', bounds, [list, tuple]) + for b in self._bounds: + _ = check_param_multi_types('bound', b, [int, float]) self._sparse = check_param_type('sparse', sparse, bool) - for b in self._bounds: - _ = check_param_multi_types('bound', b, [int, float]) + self._model_type = check_param_type('model_type', model_type, str) + if self._model_type not in ('classification', 'detection'): + msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type) + LOGGER.error(TAG, msg) + raise ValueError(msg) + self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio) + if self._reserve_ratio > 1: + msg = 'reserve_ratio should be less than 1.0, but got {}.'.format(self._reserve_ratio) + LOGGER.error(TAG, msg) + raise ValueError(TAG, msg) def generate(self, inputs, labels): """ Generate adversarial examples based on input samples and original labels. Args: - inputs (numpy.ndarray): Input samples. - labels (numpy.ndarray): Original labels. + inputs (Union[numpy.ndarray, tuple]): Input samples. The format of inputs can be (inputs1, input2, ...) \ + or only one array if model_type='detection' + labels (Union[numpy.ndarray, tuple]): Original labels. The format of labels should be \ + (gt_boxes, gt_labels) if model_type='detection'. Returns: numpy.ndarray, adversarial examples. @@ -88,67 +162,144 @@ class DeepFool(Attack): Examples: >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) """ - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - if not self._sparse: - labels = np.argmax(labels, axis=1) - inputs_dtype = inputs.dtype - iteration = 0 - origin_labels = labels - cur_labels = origin_labels.copy() - weight = np.squeeze(np.zeros(inputs.shape[1:])) - r_tot = np.zeros(inputs.shape) - x_origin = inputs - while np.any(cur_labels == origin_labels) and iteration < self._max_iters: - preds = self._network(Tensor(inputs)).asnumpy() - grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) - for idx in range(inputs.shape[0]): - diff_w = np.inf - label = origin_labels[idx] - if cur_labels[idx] != label: - continue - for k in range(self._num_classes): - if k == label: + if self._model_type == 'detection': + images, auxiliary_inputs = inputs[0], inputs[1:] + gt_boxes, gt_labels = labels + _, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network) + if not self._sparse: + gt_labels = np.argmax(gt_labels, axis=2) + origin_labels = np.zeros(gt_labels.shape[0]) + for i in range(gt_labels.shape[0]): + origin_labels[i] = np.argmax(np.bincount(gt_labels[i])) + images_dtype = images.dtype + iteration = 0 + num_boxes = gt_labels.shape[1] + merge_net = _GetLogits(self._network) + detection_net_grad = GradWrap(merge_net) + weight = np.squeeze(np.zeros(images.shape[1:])) + r_tot = np.zeros(images.shape) + x_origin = images + while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \ + self._reserve_ratio) and iteration < self._max_iters: + preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy() + grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs, + num_boxes, self._num_classes) + for idx in range(images.shape[0]): + diff_w = np.inf + label = int(origin_labels[idx]) + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[idx], axis=0),) + gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0) + gt_labels_i = np.expand_dims(gt_labels[idx], axis=0) + inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i + if _is_success(inputs_i, gt_boxes_i, gt_labels_i, + self._network, gt_object_nums[idx], self._reserve_ratio): + continue + for k in range(self._num_classes): + if k == label: + continue + w_k = grads[k, idx, ...] - grads[label, idx, ...] + f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...])) + if self._norm_level == 2 or self._norm_level == '2': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) + elif self._norm_level == np.inf \ + or self._norm_level == 'inf': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) + else: + msg = 'ord {} is not available.' \ + .format(str(self._norm_level)) + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + if diff_w_k < diff_w: + diff_w = diff_w_k + weight = w_k + if self._norm_level == 2 or self._norm_level == '2': + r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) + elif self._norm_level == np.inf or self._norm_level == 'inf': + r_i = diff_w*np.sign(weight) \ + / (np.linalg.norm(weight, ord=1) + 1e-8) + else: + msg = 'ord {} is not available in normalization,' \ + .format(str(self._norm_level)) + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + r_tot[idx, ...] = r_tot[idx, ...] + r_i + + if self._bounds is not None: + clip_min, clip_max = self._bounds + images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) + images = np.clip(images, clip_min, clip_max) + else: + images = x_origin + (1 + self._overshoot)*r_tot + iteration += 1 + images = images.astype(images_dtype) + del preds_logits, grads + return images + + if self._model_type == 'classification': + inputs, labels = check_pair_numpy_param('inputs', inputs, + 'labels', labels) + if not self._sparse: + labels = np.argmax(labels, axis=1) + inputs_dtype = inputs.dtype + iteration = 0 + origin_labels = labels + cur_labels = origin_labels.copy() + weight = np.squeeze(np.zeros(inputs.shape[1:])) + r_tot = np.zeros(inputs.shape) + x_origin = inputs + while np.any(cur_labels == origin_labels) and iteration < self._max_iters: + preds = self._network(Tensor(inputs)).asnumpy() + grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) + for idx in range(inputs.shape[0]): + diff_w = np.inf + label = origin_labels[idx] + if cur_labels[idx] != label: continue - w_k = grads[k, idx, ...] - grads[label, idx, ...] - f_k = preds[idx, k] - preds[idx, label] + for k in range(self._num_classes): + if k == label: + continue + w_k = grads[k, idx, ...] - grads[label, idx, ...] + f_k = preds[idx, k] - preds[idx, label] + if self._norm_level == 2 or self._norm_level == '2': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) + elif self._norm_level == np.inf \ + or self._norm_level == 'inf': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) + else: + msg = 'ord {} is not available.' \ + .format(str(self._norm_level)) + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + if diff_w_k < diff_w: + diff_w = diff_w_k + weight = w_k + if self._norm_level == 2 or self._norm_level == '2': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) - elif self._norm_level == np.inf \ - or self._norm_level == 'inf': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) + r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) + elif self._norm_level == np.inf or self._norm_level == 'inf': + r_i = diff_w*np.sign(weight) \ + / (np.linalg.norm(weight, ord=1) + 1e-8) else: - msg = 'ord {} is not available.' \ + msg = 'ord {} is not available in normalization.' \ .format(str(self._norm_level)) LOGGER.error(TAG, msg) raise NotImplementedError(msg) - if diff_w_k < diff_w: - diff_w = diff_w_k - weight = w_k - - if self._norm_level == 2 or self._norm_level == '2': - r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) - elif self._norm_level == np.inf or self._norm_level == 'inf': - r_i = diff_w*np.sign(weight) \ - / (np.linalg.norm(weight, ord=1) + 1e-8) + r_tot[idx, ...] = r_tot[idx, ...] + r_i + + if self._bounds is not None: + clip_min, clip_max = self._bounds + inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max + - clip_min) + inputs = np.clip(inputs, clip_min, clip_max) else: - msg = 'ord {} is not available in normalization.' \ - .format(str(self._norm_level)) - LOGGER.error(TAG, msg) - raise NotImplementedError(msg) - r_tot[idx, ...] = r_tot[idx, ...] + r_i - - if self._bounds is not None: - clip_min, clip_max = self._bounds - inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max - - clip_min) - inputs = np.clip(inputs, clip_min, clip_max) - else: - inputs = x_origin + (1 + self._overshoot)*r_tot - cur_labels = np.argmax( - self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), - axis=1) - iteration += 1 - inputs = inputs.astype(inputs_dtype) - del preds, grads - return inputs + inputs = x_origin + (1 + self._overshoot)*r_tot + cur_labels = np.argmax( + self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), + axis=1) + iteration += 1 + inputs = inputs.astype(inputs_dtype) + del preds, grads + return inputs + return None diff --git a/mindarmour/adv_robustness/attacks/gradient_method.py b/mindarmour/adv_robustness/attacks/gradient_method.py index 9011558..cd7e2e1 100644 --- a/mindarmour/adv_robustness/attacks/gradient_method.py +++ b/mindarmour/adv_robustness/attacks/gradient_method.py @@ -18,12 +18,11 @@ from abc import abstractmethod import numpy as np -from mindspore import Tensor from mindspore.nn import Cell -from mindarmour.utils.util import WithLossCell, GradWrapWithLoss +from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple from mindarmour.utils.logger import LogUtil -from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ +from mindarmour.utils._check_param import check_model, check_inputs_labels, \ normalize_value, check_value_positive, check_param_multi_types, \ check_norm_level, check_param_type from .attack import Attack @@ -91,18 +90,7 @@ class GradientMethod(Attack): Returns: numpy.ndarray, generated adversarial examples. """ - inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs - if isinstance(inputs, tuple): - for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) - if isinstance(labels, tuple): - for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'labels[{}]'.format(i), labels_item) - else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + inputs_image, inputs, labels = check_inputs_labels(inputs, labels) self._dtype = inputs_image.dtype gradient = self._gradient(inputs, labels) # use random method or not @@ -196,18 +184,8 @@ class FastGradientMethod(GradientMethod): Returns: numpy.ndarray, gradient of inputs. """ - if isinstance(inputs, tuple): - inputs_tensor = tuple() - for item in inputs: - inputs_tensor += (Tensor(item),) - else: - inputs_tensor = (Tensor(inputs),) - if isinstance(labels, tuple): - labels_tensor = tuple() - for item in labels: - labels_tensor += (Tensor(item),) - else: - labels_tensor = (Tensor(labels),) + inputs_tensor = to_tensor_tuple(inputs) + labels_tensor = to_tensor_tuple(labels) out_grad = self._grad_all(*inputs_tensor, *labels_tensor) if isinstance(out_grad, tuple): out_grad = out_grad[0] @@ -315,18 +293,8 @@ class FastGradientSignMethod(GradientMethod): Returns: numpy.ndarray, gradient of inputs. """ - if isinstance(inputs, tuple): - inputs_tensor = tuple() - for item in inputs: - inputs_tensor += (Tensor(item),) - else: - inputs_tensor = (Tensor(inputs),) - if isinstance(labels, tuple): - labels_tensor = tuple() - for item in labels: - labels_tensor += (Tensor(item),) - else: - labels_tensor = (Tensor(labels),) + inputs_tensor = to_tensor_tuple(inputs) + labels_tensor = to_tensor_tuple(labels) out_grad = self._grad_all(*inputs_tensor, *labels_tensor) if isinstance(out_grad, tuple): out_grad = out_grad[0] diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py index 3537585..94acf25 100644 --- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py +++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py @@ -18,11 +18,10 @@ import numpy as np from PIL import Image, ImageOps from mindspore.nn import Cell -from mindspore import Tensor from mindarmour.utils.logger import LogUtil -from mindarmour.utils.util import WithLossCell, GradWrapWithLoss -from mindarmour.utils._check_param import check_pair_numpy_param, \ +from mindarmour.utils.util import WithLossCell, GradWrapWithLoss, to_tensor_tuple +from mindarmour.utils._check_param import check_inputs_labels, \ normalize_value, check_model, check_value_positive, check_int_positive, \ check_param_type, check_norm_level, check_param_multi_types from .attack import Attack @@ -223,18 +222,7 @@ class BasicIterativeMethod(IterativeGradientMethod): >>> [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], >>> [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]) """ - inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs - if isinstance(inputs, tuple): - for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) - if isinstance(labels, tuple): - for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'labels[{}]'.format(i), labels_item) - else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs_image if self._bounds is not None: clip_min, clip_max = self._bounds @@ -322,18 +310,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): >>> [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], >>> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]) """ - inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs - if isinstance(inputs, tuple): - for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) - if isinstance(labels, tuple): - for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'labels[{}]'.format(i), labels_item) - else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs_image momentum = 0 if self._bounds is not None: @@ -392,18 +369,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) """ # get grad of loss over x - if isinstance(inputs, tuple): - inputs_tensor = tuple() - for item in inputs: - inputs_tensor += (Tensor(item),) - else: - inputs_tensor = (Tensor(inputs),) - if isinstance(labels, tuple): - labels_tensor = tuple() - for item in labels: - labels_tensor += (Tensor(item),) - else: - labels_tensor = (Tensor(labels),) + inputs_tensor = to_tensor_tuple(inputs) + labels_tensor = to_tensor_tuple(labels) out_grad = self._loss_grad(*inputs_tensor, *labels_tensor) if isinstance(out_grad, tuple): out_grad = out_grad[0] @@ -473,18 +440,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): >>> [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], >>> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) """ - inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs - if isinstance(inputs, tuple): - for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) - if isinstance(labels, tuple): - for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'labels[{}]'.format(i), labels_item) - else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + inputs_image, inputs, labels = check_inputs_labels(inputs, labels) arr_x = inputs_image if self._bounds is not None: clip_min, clip_max = self._bounds diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 1adf18d..9c1318c 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -327,3 +327,22 @@ def check_detection_inputs(inputs, labels): LOGGER.error(TAG, msg) raise ValueError(msg) return images, auxiliary_inputs, gt_boxes, gt_labels + + +def check_inputs_labels(inputs, labels): + """check inputs and labels is valid for white box method.""" + _ = check_param_multi_types('inputs', inputs, (tuple, np.ndarray)) + _ = check_param_multi_types('labels', labels, (tuple, np.ndarray)) + inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs + if isinstance(inputs, tuple): + for i, inputs_item in enumerate(inputs): + _ = check_pair_numpy_param('inputs_image', inputs_image, \ + 'inputs[{}]'.format(i), inputs_item) + if isinstance(labels, tuple): + for i, labels_item in enumerate(labels): + _ = check_pair_numpy_param('inputs', inputs_image, \ + 'labels[{}]'.format(i), labels_item) + else: + _ = check_pair_numpy_param('inputs', inputs_image, \ + 'labels', labels) + return inputs_image, inputs, labels diff --git a/mindarmour/utils/util.py b/mindarmour/utils/util.py index 64c7f18..16a170a 100644 --- a/mindarmour/utils/util.py +++ b/mindarmour/utils/util.py @@ -17,7 +17,7 @@ from mindspore import Tensor from mindspore.nn import Cell from mindspore.ops.composite import GradOperation -from mindarmour.utils._check_param import check_numpy_param +from mindarmour.utils._check_param import check_numpy_param, check_param_multi_types from .logger import LogUtil @@ -54,6 +54,44 @@ def jacobian_matrix(grad_wrap_net, inputs, num_classes): return np.asarray(grads_matrix) +def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes): + """ + Calculate the Jacobian matrix for inputs, specifically for object detection model. + + Args: + grad_wrap_net (Cell): A network wrapped by GradWrap. + inputs (numpy.ndarray): Input samples. + num_boxes (int): Number of boxes infered by each image. + num_classes (int): Number of labels of model output. + + Returns: + numpy.ndarray, the Jacobian matrix of inputs. (labels, batch_size, ...) + + Raises: + ValueError: If grad_wrap_net is not a instance of class `GradWrap`. + """ + if not isinstance(grad_wrap_net, GradWrap): + msg = 'grad_wrap_net be and instance of class `GradWrap`.' + LOGGER.error(TAG, msg) + raise ValueError(msg) + grad_wrap_net.set_train() + grads_matrix = [] + inputs_tensor = tuple() + if isinstance(inputs, tuple): + for item in inputs: + inputs_tensor += (Tensor(item),) + else: + inputs_tensor += (Tensor(inputs),) + for idx in range(num_classes): + batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] + sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32) + sens[:, :, idx] = 1.0 + + grads = grad_wrap_net(*(inputs_tensor), Tensor(sens)) + grads_matrix.append(grads.asnumpy()) + return np.asarray(grads_matrix) + + class WithLossCell(Cell): """ Wrap the network with loss function. @@ -152,19 +190,19 @@ class GradWrap(Cell): self.grad = GradOperation(get_all=False, sens_param=True) self.network = network - def construct(self, inputs, weight): + def construct(self, *data): """ Compute jacobian matrix. Args: - inputs (Tensor): Inputs of network. - weight (Tensor): Weight of each gradient, `weight` has the same - shape with labels. + data (Tensor): Data consists of inputs and weight. \ + - inputs: Inputs of network. \ + - weight: Weight of each gradient, 'weight' has the same shape with labels. Returns: Tensor, Jacobian matrix. """ - gout = self.grad(self.network)(inputs, weight) + gout = self.grad(self.network)(*data) return gout @@ -199,3 +237,15 @@ def calculate_iou(box_i, box_j): return 0 inner_area = (inner_right_line - inner_left_line)*(inner_top_line - inner_bottom_line) return inner_area / (s_i + s_j - inner_area) + + +def to_tensor_tuple(inputs_ori): + """Transfer inputs data into tensor type.""" + inputs_ori = check_param_multi_types('inputs_ori', inputs_ori, [np.ndarray, tuple]) + if isinstance(inputs_ori, tuple): + inputs_tensor = tuple() + for item in inputs_ori: + inputs_tensor += (Tensor(item),) + else: + inputs_tensor = (Tensor(inputs_ori),) + return inputs_tensor diff --git a/tests/ut/python/adv_robustness/attacks/test_deep_fool.py b/tests/ut/python/adv_robustness/attacks/test_deep_fool.py index 7f82a0b..d12a7d1 100644 --- a/tests/ut/python/adv_robustness/attacks/test_deep_fool.py +++ b/tests/ut/python/adv_robustness/attacks/test_deep_fool.py @@ -54,6 +54,23 @@ class Net(Cell): return out +class Net2(Cell): + """ + Construct the network of target model, specifically for detection model test case. + + Examples: + >>> net = Net2() + """ + def __init__(self): + super(Net2, self).__init__() + self._softmax = P.Softmax() + + def construct(self, inputs1, inputs2): + out1 = self._softmax(inputs2) + out2 = self._softmax(inputs1) + return out1, out2 + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -84,6 +101,27 @@ def test_deepfool_attack(): @pytest.mark.platform_x86_ascend_training @pytest.mark.env_card @pytest.mark.component_mindarmour +def test_deepfool_attack_detection(): + """ + Deepfool-Attack test + """ + net = Net2() + inputs1_np = np.random.random((2, 10, 10)).astype(np.float32) + inputs2_np = np.random.random((2, 10, 5)).astype(np.float32) + gt_boxes = inputs1_np[:, :, 0: 5] + gt_labels = np.argmax(inputs1_np, axis=2) + num_classes = 10 + + attack = DeepFool(net, num_classes, model_type='detection', reserve_ratio=0.3, + bounds=(0.0, 1.0)) + _ = attack.generate((inputs1_np, inputs2_np), (gt_boxes, gt_labels)) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour def test_deepfool_attack_inf(): """ Deepfool-Attack test