Browse Source

fix GradOperation failed when compute jacobian matrix for detection

tags/v1.6.0
liuluobin 4 years ago
parent
commit
fe6cd1fdf6
3 changed files with 137 additions and 127 deletions
  1. +134
    -125
      mindarmour/adv_robustness/attacks/deep_fool.py
  2. +1
    -2
      mindarmour/utils/util.py
  3. +2
    -0
      tests/ut/python/privacy/evaluation/test_membership_inference.py

+ 134
- 125
mindarmour/adv_robustness/attacks/deep_fool.py View File

@@ -124,7 +124,6 @@ class DeepFool(Attack):
reserve_ratio=0.3, max_iters=50, overshoot=0.02, norm_level=2, bounds=None, sparse=True):
super(DeepFool, self).__init__()
self._network = check_model('network', network, Cell)
self._network.set_grad(True)
self._max_iters = check_int_positive('max_iters', max_iters)
self._overshoot = check_value_positive('overshoot', overshoot)
self._norm_level = check_norm_level(norm_level)
@@ -169,143 +168,153 @@ class DeepFool(Attack):
>>> advs = generate([[0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], [1, 2])
"""
if self._model_type == 'detection':
images, auxiliary_inputs = inputs[0], inputs[1:]
gt_boxes, gt_labels = labels
_, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network)
if not self._sparse:
gt_labels = np.argmax(gt_labels, axis=2)
origin_labels = np.zeros(gt_labels.shape[0])
for i in range(gt_labels.shape[0]):
origin_labels[i] = np.argmax(np.bincount(gt_labels[i]))
images_dtype = images.dtype
iteration = 0
num_boxes = gt_labels.shape[1]
merge_net = _GetLogits(self._network)
detection_net_grad = GradWrap(merge_net)
weight = np.squeeze(np.zeros(images.shape[1:]))
r_tot = np.zeros(images.shape)
x_origin = images
while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \
self._reserve_ratio) and iteration < self._max_iters:
preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy()
grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs,
num_boxes, self._num_classes)
for idx in range(images.shape[0]):
diff_w = np.inf
label = int(origin_labels[idx])
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[idx], axis=0),)
gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0)
gt_labels_i = np.expand_dims(gt_labels[idx], axis=0)
inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i
if _is_success(inputs_i, gt_boxes_i, gt_labels_i,
self._network, gt_object_nums[idx], self._reserve_ratio):
return self._generate_detection(inputs, labels)
if self._model_type == 'classification':
return self._generate_classification(inputs, labels)
return None


def _generate_detection(self, inputs, labels):
"""Generate adversarial examples in detection scenario"""
images, auxiliary_inputs = inputs[0], inputs[1:]
gt_boxes, gt_labels = labels
_, gt_object_nums = _deepfool_detection_scores(inputs, gt_boxes, gt_labels, self._network)
if not self._sparse:
gt_labels = np.argmax(gt_labels, axis=2)
origin_labels = np.zeros(gt_labels.shape[0])
for i in range(gt_labels.shape[0]):
origin_labels[i] = np.argmax(np.bincount(gt_labels[i]))
images_dtype = images.dtype
iteration = 0
num_boxes = gt_labels.shape[1]
merge_net = _GetLogits(self._network)
detection_net_grad = GradWrap(merge_net)
weight = np.squeeze(np.zeros(images.shape[1:]))
r_tot = np.zeros(images.shape)
x_origin = images
while not _is_success((images,) + auxiliary_inputs, gt_boxes, gt_labels, self._network, gt_object_nums, \
self._reserve_ratio) and iteration < self._max_iters:
preds_logits = merge_net(*to_tensor_tuple(images), *to_tensor_tuple(auxiliary_inputs)).asnumpy()
grads = jacobian_matrix_for_detection(detection_net_grad, (images,) + auxiliary_inputs,
num_boxes, self._num_classes)
for idx in range(images.shape[0]):
diff_w = np.inf
label = int(origin_labels[idx])
auxiliary_input_i = tuple()
for item in auxiliary_inputs:
auxiliary_input_i += (np.expand_dims(item[idx], axis=0),)
gt_boxes_i = np.expand_dims(gt_boxes[idx], axis=0)
gt_labels_i = np.expand_dims(gt_labels[idx], axis=0)
inputs_i = (np.expand_dims(images[idx], axis=0),) + auxiliary_input_i
if _is_success(inputs_i, gt_boxes_i, gt_labels_i,
self._network, gt_object_nums[idx], self._reserve_ratio):
continue
for k in range(self._num_classes):
if k == label:
continue
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...]))
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = np.mean(np.abs(preds_logits[idx, :, k, ...] - preds_logits[idx, :, label, ...]))
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available in normalization,' \
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else:
images = x_origin + (1 + self._overshoot)*r_tot
iteration += 1
images = images.astype(images_dtype)
del preds_logits, grads
return images
msg = 'ord {} is not available in normalization,' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
images = x_origin + (1 + self._overshoot)*r_tot*(clip_max-clip_min)
images = np.clip(images, clip_min, clip_max)
else:
images = x_origin + (1 + self._overshoot)*r_tot
iteration += 1
images = images.astype(images_dtype)
del preds_logits, grads
return images


if self._model_type == 'classification':
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
if self._norm_level == 2 or self._norm_level == '2':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

def _generate_classification(self, inputs, labels):
"""Generate adversarial examples in classification scenario"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
inputs_dtype = inputs.dtype
iteration = 0
origin_labels = labels
cur_labels = origin_labels.copy()
weight = np.squeeze(np.zeros(inputs.shape[1:]))
r_tot = np.zeros(inputs.shape)
x_origin = inputs
while np.any(cur_labels == origin_labels) and iteration < self._max_iters:
preds = self._network(Tensor(inputs)).asnumpy()
grads = jacobian_matrix(self._net_grad, inputs, self._num_classes)
for idx in range(inputs.shape[0]):
diff_w = np.inf
label = origin_labels[idx]
if cur_labels[idx] != label:
continue
for k in range(self._num_classes):
if k == label:
continue
w_k = grads[k, idx, ...] - grads[label, idx, ...]
f_k = preds[idx, k] - preds[idx, label]
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
diff_w_k = abs(f_k) / (np.linalg.norm(w_k) + 1e-8)
elif self._norm_level == np.inf \
or self._norm_level == 'inf':
diff_w_k = abs(f_k) / (np.linalg.norm(w_k, ord=1) + 1e-8)
else:
msg = 'ord {} is not available in normalization.' \
msg = 'ord {} is not available.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i
if diff_w_k < diff_w:
diff_w = diff_w_k
weight = w_k

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
if self._norm_level == 2 or self._norm_level == '2':
r_i = diff_w*weight / (np.linalg.norm(weight) + 1e-8)
elif self._norm_level == np.inf or self._norm_level == 'inf':
r_i = diff_w*np.sign(weight) \
/ (np.linalg.norm(weight, ord=1) + 1e-8)
else:
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs
return None
msg = 'ord {} is not available in normalization.' \
.format(str(self._norm_level))
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
r_tot[idx, ...] = r_tot[idx, ...] + r_i

if self._bounds is not None:
clip_min, clip_max = self._bounds
inputs = x_origin + (1 + self._overshoot)*r_tot*(clip_max
- clip_min)
inputs = np.clip(inputs, clip_min, clip_max)
else:
inputs = x_origin + (1 + self._overshoot)*r_tot
cur_labels = np.argmax(
self._network(Tensor(inputs.astype(inputs_dtype))).asnumpy(),
axis=1)
iteration += 1
inputs = inputs.astype(inputs_dtype)
del preds, grads
return inputs

+ 1
- 2
mindarmour/utils/util.py View File

@@ -86,9 +86,8 @@ def jacobian_matrix_for_detection(grad_wrap_net, inputs, num_boxes, num_classes)
inputs_tensor += (Tensor(inputs),)
for idx in range(num_classes):
batch_size = inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
sens = np.zeros((batch_size, num_boxes, num_classes)).astype(np.float32)
sens = np.zeros((batch_size, num_boxes, num_classes), np.float32)
sens[:, :, idx] = 1.0

grads = grad_wrap_net(*(inputs_tensor), Tensor(sens))
grads_matrix.append(grads.asnumpy())
return np.asarray(grads_matrix)


+ 2
- 0
tests/ut/python/privacy/evaluation/test_membership_inference.py View File

@@ -60,6 +60,8 @@ def test_get_membership_inference_object():
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_membership_inference_object_train():


Loading…
Cancel
Save