From fe6cd1fdf6bcdf54f7f89730da097ea29af88dae Mon Sep 17 00:00:00 2001 From: liuluobin Date: Tue, 29 Jun 2021 16:44:39 +0800 Subject: [PATCH] fix GradOperation failed when compute jacobian matrix for detection --- mindarmour/adv_robustness/attacks/deep_fool.py | 259 +++++++++++---------- mindarmour/utils/util.py | 3 +- .../evaluation/test_membership_inference.py | 2 + 3 files changed, 137 insertions(+), 127 deletions(-) diff --git a/mindarmour/adv_robustness/attacks/deep_fool.py b/mindarmour/adv_robustness/attacks/deep_fool.py index db8cca3..a303b17 100644 --- a/mindarmour/adv_robustness/attacks/deep_fool.py +++ b/mindarmour/adv_robustness/attacks/deep_fool.py @@ -124,7 +124,6 @@ class DeepFool(Attack): reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True): super(DeepFool, self).__init__() self._network = check_model('network', network, Cell) - self._network.set_grad(True) self._max_iters = check_int_positive('max_iters', max_iters) self._overshoot = check_value_positive('overshoot', overshoot) self._norm_level = check_norm_level(norm_level) @@ -169,143 +168,153 @@ class DeepFool(Attack): >>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2]) """ if self._model_type == 'detection': - images, auxiliary_inputs = inputs[0], inputs[1:] - gt_boxes, gt_labels = labels - _, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network) - if not self._sparse: - gt_labels = np.argmax(gt_labels, axis=2) - origin_labels = np.zeros(gt_labels.shape[0]) - for i in range(gt_labels.shape[0]): - origin_labels[i] = np.argmax(np.bincount(gt_labels[i])) - images_dtype = images.dtype - iteration = 0 - num_boxes = gt_labels.shape[1] - merge_net = _GetLogits(self._network) - detection_net_grad = GradWrap(merge_net) - weight = np.squeeze(np.zeros(images.shape[1:])) - r_tot = np.zeros(images.shape) - x_origin = images - while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \ - self._reserve_ratio) and iteration < self._max_iters: - preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy() - grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs, - num_boxes, self._num_classes) - for idx in range(images.shape[0]): - diff_w = np.inf - label = int(origin_labels[idx]) - auxiliary_input_i = tuple() - for item in auxiliary_inputs: - auxiliary_input_i += (np.expand_dims(item[idx], axis=0),) - gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0) - gt_labels_i = np.expand_dims(gt_labels[idx], axis=0) - inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i - if _is_success(inputs_i, gt_boxes_i, gt_labels_i, - self._network, gt_object_nums[idx], self._reserve_ratio): + return self._generate_detection(inputs, labels) + if self._model_type == 'classification': + return self._generate_classification(inputs, labels) + return None + + + def _generate_detection(self, inputs, labels): + """Generate adversarial examples in detection scenario""" + images, auxiliary_inputs = inputs[0], inputs[1:] + gt_boxes, gt_labels = labels + _, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network) + if not self._sparse: + gt_labels = np.argmax(gt_labels, axis=2) + origin_labels = np.zeros(gt_labels.shape[0]) + for i in range(gt_labels.shape[0]): + origin_labels[i] = np.argmax(np.bincount(gt_labels[i])) + images_dtype = images.dtype + iteration = 0 + num_boxes = gt_labels.shape[1] + merge_net = 
_GetLogits(self._network) + detection_net_grad = GradWrap(merge_net) + weight = np.squeeze(np.zeros(images.shape[1:])) + r_tot = np.zeros(images.shape) + x_origin = images + while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \ + self._reserve_ratio) and iteration < self._max_iters: + preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy() + grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs, + num_boxes, self._num_classes) + for idx in range(images.shape[0]): + diff_w = np.inf + label = int(origin_labels[idx]) + auxiliary_input_i = tuple() + for item in auxiliary_inputs: + auxiliary_input_i += (np.expand_dims(item[idx], axis=0),) + gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0) + gt_labels_i = np.expand_dims(gt_labels[idx], axis=0) + inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i + if _is_success(inputs_i, gt_boxes_i, gt_labels_i, + self._network, gt_object_nums[idx], self._reserve_ratio): + continue + for k in range(self._num_classes): + if k == label: continue - for k in range(self._num_classes): - if k == label: - continue - w_k = grads[k, idx, ...] - grads[label, idx, ...] - f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...])) - if self._norm_level == 2 or self._norm_level == '2': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) - elif self._norm_level == np.inf \ - or self._norm_level == 'inf': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) - else: - msg = 'ord {} is not available.' \ - .format(str(self._norm_level)) - LOGGER.error(TAG, msg) - raise NotImplementedError(msg) - if diff_w_k < diff_w: - diff_w = diff_w_k - weight = w_k + w_k = grads[k, idx, ...] - grads[label, idx, ...] + f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...])) if self._norm_level == 2 or self._norm_level == '2': - r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) - elif self._norm_level == np.inf or self._norm_level == 'inf': - r_i = diff_w*np.sign(weight) \ - / (np.linalg.norm(weight, ord=1) + 1e-8) + diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) + elif self._norm_level == np.inf \ + or self._norm_level == 'inf': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) else: - msg = 'ord {} is not available in normalization,' \ + msg = 'ord {} is not available.' \ .format(str(self._norm_level)) LOGGER.error(TAG, msg) raise NotImplementedError(msg) - r_tot[idx, ...] = r_tot[idx, ...] + r_i - - if self._bounds is not None: - clip_min, clip_max = self._bounds - images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) - images = np.clip(images, clip_min, clip_max) + if diff_w_k < diff_w: + diff_w = diff_w_k + weight = w_k + if self._norm_level == 2 or self._norm_level == '2': + r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) + elif self._norm_level == np.inf or self._norm_level == 'inf': + r_i = diff_w*np.sign(weight) \ + / (np.linalg.norm(weight, ord=1) + 1e-8) else: - images = x_origin + (1 + self._overshoot)*r_tot - iteration += 1 - images = images.astype(images_dtype) - del preds_logits, grads - return images + msg = 'ord {} is not available in normalization,' \ + .format(str(self._norm_level)) + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + r_tot[idx, ...] = r_tot[idx, ...] 
+ r_i + + if self._bounds is not None: + clip_min, clip_max = self._bounds + images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min) + images = np.clip(images, clip_min, clip_max) + else: + images = x_origin + (1 + self._overshoot)*r_tot + iteration += 1 + images = images.astype(images_dtype) + del preds_logits, grads + return images + - if self._model_type == 'classification': - inputs, labels = check_pair_numpy_param('inputs', inputs, - 'labels', labels) - if not self._sparse: - labels = np.argmax(labels, axis=1) - inputs_dtype = inputs.dtype - iteration = 0 - origin_labels = labels - cur_labels = origin_labels.copy() - weight = np.squeeze(np.zeros(inputs.shape[1:])) - r_tot = np.zeros(inputs.shape) - x_origin = inputs - while np.any(cur_labels == origin_labels) and iteration < self._max_iters: - preds = self._network(Tensor(inputs)).asnumpy() - grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) - for idx in range(inputs.shape[0]): - diff_w = np.inf - label = origin_labels[idx] - if cur_labels[idx] != label: - continue - for k in range(self._num_classes): - if k == label: - continue - w_k = grads[k, idx, ...] - grads[label, idx, ...] - f_k = preds[idx, k] - preds[idx, label] - if self._norm_level == 2 or self._norm_level == '2': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) - elif self._norm_level == np.inf \ - or self._norm_level == 'inf': - diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) - else: - msg = 'ord {} is not available.' \ - .format(str(self._norm_level)) - LOGGER.error(TAG, msg) - raise NotImplementedError(msg) - if diff_w_k < diff_w: - diff_w = diff_w_k - weight = w_k + def _generate_classification(self, inputs, labels): + """Generate adversarial examples in classification scenario""" + inputs, labels = check_pair_numpy_param('inputs', inputs, + 'labels', labels) + if not self._sparse: + labels = np.argmax(labels, axis=1) + inputs_dtype = inputs.dtype + iteration = 0 + origin_labels = labels + cur_labels = origin_labels.copy() + weight = np.squeeze(np.zeros(inputs.shape[1:])) + r_tot = np.zeros(inputs.shape) + x_origin = inputs + while np.any(cur_labels == origin_labels) and iteration < self._max_iters: + preds = self._network(Tensor(inputs)).asnumpy() + grads = jacobian_matrix(self._net_grad, inputs, self._num_classes) + for idx in range(inputs.shape[0]): + diff_w = np.inf + label = origin_labels[idx] + if cur_labels[idx] != label: + continue + for k in range(self._num_classes): + if k == label: + continue + w_k = grads[k, idx, ...] - grads[label, idx, ...] + f_k = preds[idx, k] - preds[idx, label] if self._norm_level == 2 or self._norm_level == '2': - r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) - elif self._norm_level == np.inf or self._norm_level == 'inf': - r_i = diff_w*np.sign(weight) \ - / (np.linalg.norm(weight, ord=1) + 1e-8) + diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8) + elif self._norm_level == np.inf \ + or self._norm_level == 'inf': + diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8) else: - msg = 'ord {} is not available in normalization.' \ + msg = 'ord {} is not available.' \ .format(str(self._norm_level)) LOGGER.error(TAG, msg) raise NotImplementedError(msg) - r_tot[idx, ...] = r_tot[idx, ...] 
+ r_i + if diff_w_k < diff_w: + diff_w = diff_w_k + weight = w_k - if self._bounds is not None: - clip_min, clip_max = self._bounds - inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max - - clip_min) - inputs = np.clip(inputs, clip_min, clip_max) + if self._norm_level == 2 or self._norm_level == '2': + r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8) + elif self._norm_level == np.inf or self._norm_level == 'inf': + r_i = diff_w*np.sign(weight) \ + / (np.linalg.norm(weight, ord=1) + 1e-8) else: - inputs = x_origin + (1 + self._overshoot)*r_tot - cur_labels = np.argmax( - self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), - axis=1) - iteration += 1 - inputs = inputs.astype(inputs_dtype) - del preds, grads - return inputs - return None + msg = 'ord {} is not available in normalization.' \ + .format(str(self._norm_level)) + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + r_tot[idx, ...] = r_tot[idx, ...] + r_i + + if self._bounds is not None: + clip_min, clip_max = self._bounds + inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max + - clip_min) + inputs = np.clip(inputs, clip_min, clip_max) + else: + inputs = x_origin + (1 + self._overshoot)*r_tot + cur_labels = np.argmax( + self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(), + axis=1) + iteration += 1 + inputs = inputs.astype(inputs_dtype) + del preds, grads + return inputs diff --git a/mindarmour/utils/util.py b/mindarmour/utils/util.py index bb7e3c5..d787d39 100644 --- a/mindarmour/utils/util.py +++ b/mindarmour/utils/util.py @@ -86,9 +86,8 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes) inputs_tensor += (Tensor(inputs),) for idx in range(num_classes): batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] - sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32) + sens = np.zeros((batch_size, num_boxes, num_classes), np.float32) sens[:, :, idx] = 1.0 - grads = grad_wrap_net(*(inputs_tensor), Tensor(sens)) grads_matrix.append(grads.asnumpy()) return np.asarray(grads_matrix) diff --git a/tests/ut/python/privacy/evaluation/test_membership_inference.py b/tests/ut/python/privacy/evaluation/test_membership_inference.py index 6353911..8ff0a09 100644 --- a/tests/ut/python/privacy/evaluation/test_membership_inference.py +++ b/tests/ut/python/privacy/evaluation/test_membership_inference.py @@ -60,6 +60,8 @@ def test_get_membership_inference_object(): @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_membership_inference_object_train():
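
For readers of this change, a minimal NumPy sketch of the per-class sensitivity idea behind jacobian_matrix_for_detection, using a toy linear detection head as a stand-in (the head, the shapes, and the variable names below are illustrative assumptions, not code from this patch):

import numpy as np

batch_size, num_boxes, num_classes, feat_dim = 2, 4, 3, 5
rng = np.random.default_rng(0)
weights = rng.normal(size=(feat_dim, num_boxes * num_classes))
x = rng.normal(size=(batch_size, feat_dim))

def toy_logits(inputs):
    # Toy "detection head": per-box, per-class scores from a single linear layer.
    return (inputs @ weights).reshape(batch_size, num_boxes, num_classes)

grads_matrix = []
for idx in range(num_classes):
    # Same one-hot sensitivity tensor as in the patched function: it selects
    # only the idx-th class channel for the backward pass.
    sens = np.zeros((batch_size, num_boxes, num_classes), np.float32)
    sens[:, :, idx] = 1.0
    # Vector-Jacobian product for this linear head; for the real model, this is
    # what the GradWrap-wrapped network returns when it is fed Tensor(sens).
    grad_x = sens.reshape(batch_size, -1) @ weights.T
    grads_matrix.append(grad_x)
jacobian = np.asarray(grads_matrix)  # shape: (num_classes, batch_size, feat_dim)

Stacking one vector-Jacobian product per class is why _generate_detection indexes the result as grads[k, idx, ...]: class k first, batch element idx second.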
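
A short usage sketch of the refactored generate dispatch on the classification path, mirroring the docstring example kept as context in the hunk above; the three-class Net cell and the concrete values are illustrative assumptions, not part of this patch:

import numpy as np
from mindspore import nn
from mindarmour.adv_robustness.attacks import DeepFool

class Net(nn.Cell):
    """Toy three-class model: softmax over the raw 3-dim input."""
    def __init__(self):
        super(Net, self).__init__()
        self._softmax = nn.Softmax()

    def construct(self, inputs):
        return self._softmax(inputs)

net = Net()
inputs = np.array([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], dtype=np.float32)
labels = np.array([1, 2])  # sparse class indices, matching sparse=True
attack = DeepFool(net, num_classes=3, model_type='classification',
                  norm_level=2, bounds=(0.0, 1.0))
advs = attack.generate(inputs, labels)

For model_type='detection', generate instead expects inputs as (images, *auxiliary_inputs) and labels as (gt_boxes, gt_labels), which is the tuple unpacking performed at the top of _generate_detection.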