
Fix several bugs for PSOAttack and GeneticAttack.

tags/v1.2.1
jin-xiulang committed 4 years ago, commit 7aa6b285d3
11 changed files with 83 additions and 56 deletions
  1. examples/model_security/model_attacks/black_box/mnist_attack_genetic.py (+1, -1)
  2. examples/privacy/diff_privacy/lenet5_dp.py (+1, -3)
  3. examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py (+1, -3)
  4. examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py (+1, -3)
  5. examples/privacy/diff_privacy/lenet5_dp_optimizer.py (+1, -2)
  6. mindarmour/adv_robustness/attacks/attack.py (+19, -11)
  7. mindarmour/adv_robustness/attacks/black/genetic_attack.py (+28, -13)
  8. mindarmour/adv_robustness/attacks/black/pso_attack.py (+28, -17)
  9. tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py (+1, -1)
  10. tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py (+1, -1)
  11. tests/ut/python/adv_robustness/attacks/test_jsma.py (+1, -1)

examples/model_security/model_attacks/black_box/mnist_attack_genetic.py (+1, -1)

@@ -87,7 +87,7 @@ def test_genetic_attack_on_mnist():

# attacking
attack = GeneticAttack(model=model, pop_size=6, mutation_rate=0.05,
- per_bounds=0.1, step_size=0.25, temp=0.1,
+ per_bounds=0.4, step_size=0.25, temp=0.1,
sparse=True)
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
for i, true_l in enumerate(true_labels):
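
The wider per_bounds here matches how GeneticAttack now constrains its population: candidates are kept within per_bounds of the pixel range around the original image before being clipped to the valid pixel range. A minimal numpy sketch of that double clipping (not part of the commit; shapes and values are illustrative):

    import numpy as np

    per_bounds = 0.4                    # fraction of the pixel range the attack may perturb
    bounds = (0.0, 1.0)                 # clip_min, clip_max of the model input
    pixel_deep = bounds[1] - bounds[0]  # full dynamic range of a pixel

    x_ori = np.random.random((1, 32, 32))                        # toy original image
    cur_pop = x_ori + np.random.uniform(-1.0, 1.0, x_ori.shape)  # toy perturbed candidates

    # Clip to the per-example perturbation budget first, then to the valid pixel range,
    # mirroring the double np.clip introduced in genetic_attack.py below.
    cur_pop = np.clip(np.clip(cur_pop,
                              x_ori - pixel_deep*per_bounds,
                              x_ori + pixel_deep*per_bounds),
                      bounds[0], bounds[1])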


examples/privacy/diff_privacy/lenet5_dp.py (+1, -3)

@@ -107,9 +107,7 @@ if __name__ == "__main__":
raise ValueError(
"Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP noise mechanisms, this method is adding noise
- # in gradients while training. Initial_noise_multiplier is suggested to be
- # greater than 1.0, otherwise the privacy budget would be huge, which means
- # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
+ # in gradients while training. Mechanisms can be 'Gaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
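
The remaining comment distinguishes 'Gaussian' from 'AdaGaussian' by whether the noise scale decays during training. A rough numpy-only sketch of that difference (the multiplier, norm bound, and decay schedule below are illustrative assumptions, not MindArmour defaults, and the real mechanisms operate on MindSpore tensors):

    import numpy as np

    initial_noise_multiplier = 1.5   # illustrative value
    norm_bound = 1.0                 # illustrative gradient-clipping bound
    noise_decay_rate = 6e-4          # illustrative decay rate for the adaptive mechanism

    def gaussian_noise(grad):
        # 'Gaussian': the noise scale stays constant over the whole training run.
        sigma = initial_noise_multiplier*norm_bound
        return grad + np.random.normal(0.0, sigma, grad.shape)

    def ada_gaussian_noise(grad, step):
        # 'AdaGaussian': the noise scale shrinks as training proceeds (schedule assumed here).
        sigma = initial_noise_multiplier*norm_bound*(1 - noise_decay_rate)**step
        return grad + np.random.normal(0.0, sigma, grad.shape)

    grad = np.zeros(4)
    early = ada_gaussian_noise(grad, step=0)      # large noise at the start of training
    late = ada_gaussian_noise(grad, step=10000)   # noticeably smaller noise later on

The same comment cleanup is applied to the three sibling example scripts below.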


examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py (+1, -3)

@@ -106,9 +106,7 @@ if __name__ == "__main__":
raise ValueError(
"Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP noise mechanisms, this method is adding noise
- # in gradients while training. Initial_noise_multiplier is suggested to be
- # greater than 1.0, otherwise the privacy budget would be huge, which means
- # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
+ # in gradients while training. Mechanisms can be 'Gaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,


examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py (+1, -3)

@@ -106,9 +106,7 @@ if __name__ == "__main__":
raise ValueError(
"Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP noise mechanisms, this method is adding noise
- # in gradients while training. Initial_noise_multiplier is suggested to be
- # greater than 1.0, otherwise the privacy budget would be huge, which means
- # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
+ # in gradients while training. Mechanisms can be 'Gaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,


examples/privacy/diff_privacy/lenet5_dp_optimizer.py (+1, -2)

@@ -103,8 +103,7 @@ if __name__ == "__main__":
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
raise ValueError("Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
- # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
- # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
+ # Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
dp_opt.set_mechanisms(cfg.noise_mechanisms,


mindarmour/adv_robustness/attacks/attack.py (+19, -11)

@@ -179,19 +179,27 @@ class Attack:
LOGGER.info(TAG, 'Reduction begins...')
model = check_model('model', model, BlackModel)
x_ori = check_numpy_param('x_ori', x_ori)
+ _, gt_num = self._detection_scores((x_ori,) + auxiliary_inputs, gt_boxes, gt_labels, model)
best_position = check_numpy_param('best_position', best_position)
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position)
- x_shape = best_position.shape
- reduction_iters = 1000 # recover 0.1% each step
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
- for _ in range(reduction_iters):
- diff = x_ori - best_position
- res = 0.5*diff*(np.random.random(x_shape) < 0.001)
- best_position += res
- _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
- q_times += 1
- if correct_num > original_num:
- best_position -= res
+ # pylint: disable=invalid-name
+ REDUCTION_ITERS = 6 # recover 10% difference each time and recover 60% totally.
+ for _ in range(REDUCTION_ITERS):
+ BLOCK_NUM = 30 # divide the image into 30 segments
+ block_width = best_position.shape[0] // BLOCK_NUM
+ if block_width > 0:
+ for i in range(BLOCK_NUM):
+ diff = x_ori[i*block_width: (i+1)*block_width, :, :]\
+ - best_position[i*block_width:(i+1)*block_width, :, :]
+ if np.max(np.abs(diff)) >= 0.1*(self._bounds[1] - self._bounds[0]):
+ res = diff*0.1
+ best_position[i*block_width: (i+1)*block_width, :, :] += res
+ _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes,
+ gt_labels, model)
+ q_times += 1
+ if correct_num[0] > max(original_num[0], gt_num[0]*self._reserve_ratio):
+ best_position[i*block_width:(i+1)*block_width, :, :] -= res
return best_position, q_times

@staticmethod
@@ -229,7 +237,7 @@ class Attack:
max_iou_confi = 0
for j in range(gt_box_num):
iou = calculate_iou(pred_box[:4], gt_box[j][:4])
- if labels[i] == gt_label[j] and iou > iou_thres:
+ if labels[i] == gt_label[j] and iou > iou_thres and correct_label_flag[j] == 0:
max_iou_confi = max(max_iou_confi, pred_box[-1] + iou)
correct_label_flag[j] = 1
score += max_iou_confi
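
The rewritten _fast_reduction above recovers the perturbed image block by block instead of pixel by pixel. A standalone sketch of the new strategy, with a caller-supplied toy scoring function standing in for _detection_scores (function name, scoring function, and all values here are illustrative, not the library API):

    import numpy as np

    def fast_reduction_sketch(x_ori, best_position, correct_num_fn,
                              bounds=(0.0, 1.0), reserve_ratio=0.3):
        """Simplified stand-in for Attack._fast_reduction after this commit."""
        gt_num = correct_num_fn(x_ori)                 # detections on the clean image
        original_num = correct_num_fn(best_position)   # detections on the adversarial image
        q_times = 0
        REDUCTION_ITERS = 6   # recover 10% of the difference per pass, ~60% in total
        BLOCK_NUM = 30        # split the image into 30 segments along the first axis
        for _ in range(REDUCTION_ITERS):
            block_width = best_position.shape[0] // BLOCK_NUM
            if block_width <= 0:
                break
            for i in range(BLOCK_NUM):
                sl = slice(i*block_width, (i + 1)*block_width)
                diff = x_ori[sl] - best_position[sl]
                if np.max(np.abs(diff)) >= 0.1*(bounds[1] - bounds[0]):
                    res = diff*0.1
                    best_position[sl] += res        # move this block 10% back toward the original
                    q_times += 1
                    if correct_num_fn(best_position) > max(original_num, gt_num*reserve_ratio):
                        best_position[sl] -= res    # undo if the attack stops succeeding
        return best_position, q_times

    # Toy usage: a fake "detector" that just thresholds the mean pixel value.
    x = np.random.random((100, 100, 3))
    adv = np.clip(x + 0.3*np.random.random(x.shape), 0.0, 1.0)
    adv, queries = fast_reduction_sketch(x, adv.copy(), lambda img: int(img.mean() > 0.5))

The second hunk additionally uses correct_label_flag so that a single ground-truth box can no longer be matched (and counted) more than once when scoring detections.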


mindarmour/adv_robustness/attacks/black/genetic_attack.py (+28, -13)

@@ -162,7 +162,10 @@ class GeneticAttack(Attack):
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if self._sparse:
- label_squ = np.squeeze(labels)
+ if labels.size > 1:
+ label_squ = np.squeeze(labels)
+ else:
+ label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \
"and got its shape as {}.".format(labels.shape)
@@ -198,16 +201,17 @@ class GeneticAttack(Attack):
# generate particles
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
# initial perturbations
- cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep,
- (0 - self._per_bounds)*pixel_deep,
- self._per_bounds*pixel_deep)

+ cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape)
+ cur_pop = ori_copies + cur_pert
query_times = 0
iters = 0

while iters < self._max_steps:
iters += 1
- cur_pop = np.clip(
- ori_copies + cur_pert, self._bounds[0], self._bounds[1])
+ cur_pop = np.clip(np.clip(cur_pop,
+ ori_copies - pixel_deep*self._per_bounds,
+ ori_copies + pixel_deep*self._per_bounds),
+ self._bounds[0], self._bounds[1])

if self._model_type == 'classification':
pop_preds = self._model.predict(cur_pop)
@@ -235,9 +239,19 @@ class GeneticAttack(Attack):
fit_vals = abs(
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm(
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1)

if np.max(fit_vals) < 0:
self._c /= 2

+ if np.max(fit_vals) < -2:
+ LOGGER.debug(TAG,
+ 'best fitness value is %s, which is too small. We recommend that you decrease '
+ 'the value of the initialization parameter c.', np.max(fit_vals))
+ if iters < 3 and np.max(fit_vals) > 100:
+ LOGGER.debug(TAG,
+ 'best fitness value is %s, which is too large. We recommend that you increase '
+ 'the value of the initialization parameter c.', np.max(fit_vals))

if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio):
is_success = True
best_idx = np.argmin(correct_nums_adv)
@@ -252,6 +266,7 @@ class GeneticAttack(Attack):
break

best_fit = max(fit_vals)

if best_fit > self._best_fit:
self._best_fit = best_fit
self._plateau_times = 0
@@ -263,19 +278,19 @@ class GeneticAttack(Attack):
self._plateau_times = 0
if self._adaptive:
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times))
- step_p = max(self._step_size, 0.5*(0.9**self._adap_times))
+ step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times))
else:
step_noise = self._step_size
step_p = self._mutation_rate
step_temp = self._temp
- elite = cur_pert[np.argmax(fit_vals)]
+ elite = cur_pop[np.argmax(fit_vals)]
select_probs = softmax(fit_vals/step_temp)
select_args = np.arange(self._pop_size)
parents_arg = np.random.choice(
a=select_args, size=2*(self._pop_size - 1),
replace=True, p=select_probs)
- parent1 = cur_pert[parents_arg[:self._pop_size - 1]]
- parent2 = cur_pert[parents_arg[self._pop_size - 1:]]
+ parent1 = cur_pop[parents_arg[:self._pop_size - 1]]
+ parent2 = cur_pop[parents_arg[self._pop_size - 1:]]
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]]
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]]
parent2_probs = parent2_probs / (parent1_probs + parent2_probs)
@@ -290,11 +305,11 @@ class GeneticAttack(Attack):
mutated_childs = self._mutation(
childs, step_noise=self._per_bounds*step_noise,
prob=step_p)
- cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :]))
+ cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :]))

if not is_success:
LOGGER.debug(TAG, 'fail to find adversarial sample.')
- final_adv = elite + x_ori
+ final_adv = elite
if self._model_type == 'detection':
final_adv, query_times = self._fast_reduction(
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model)
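
For reference, a small numpy sketch of the fitness expression used above and of the constant c that the new debug hints ask the user to tune: fitness rewards a drop in the model's confidence and penalizes the L2 size of the perturbation (all concrete numbers below are made up):

    import numpy as np

    pop_size = 6
    c = 0.1                                  # regularization constant the new log hints refer to
    x_ori = np.random.random((3, 32, 32))
    cur_pop = np.clip(x_ori + 0.1*np.random.randn(pop_size, 3, 32, 32), 0.0, 1.0)

    confi_ori = 0.9                          # confidence on the original sample (toy value)
    confi_adv = np.random.random(pop_size)   # confidence on each candidate (toy values)

    # Same structure as the fitness computed in generate(): confidence drop minus an
    # L2 penalty on the perturbation, weighted by c.
    fit_vals = np.abs(confi_ori - confi_adv) - c/pop_size*np.linalg.norm(
        (cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1)

    # Mirrors the adjustment kept in this hunk: when every candidate scores below zero,
    # the penalty weight is halved so the search can make progress.
    if np.max(fit_vals) < 0:
        c /= 2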


mindarmour/adv_robustness/attacks/black/pso_attack.py (+28, -17)

@@ -161,15 +161,12 @@ class PSOAttack(Attack):
Returns:
numpy.ndarray, mutational inputs.
"""
- # LOGGER.info(TAG, 'Mutation happens...')
+ LOGGER.info(TAG, 'Mutation happens...')
pixel_deep = self._bounds[1] - self._bounds[0]
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep
- mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop
- if self._model_type == 'classification':
- mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop),
- cur_pop + self._per_bounds*np.abs(cur_pop)),
- self._bounds[0], self._bounds[1])
+ mutated_pop = np.clip(perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop, self._bounds[0],
+ self._bounds[1])
return mutated_pop

def generate(self, inputs, labels):
@@ -199,7 +196,10 @@ class PSOAttack(Attack):
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if self._sparse:
- label_squ = np.squeeze(labels)
+ if labels.size > 1:
+ label_squ = np.squeeze(labels)
+ else:
+ label_squ = labels
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \
"got its shape as {}.".format(labels.shape)
@@ -242,12 +242,12 @@ class PSOAttack(Attack):
# initial global optimum position
best_position = x_ori
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
- cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size,
+ cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep,
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1),
self._per_bounds*(np.abs(x_copies) + 0.1))
- par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1])
- # initial advs
- par_ori = np.copy(par)
+ par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1])
# initial optimum positions for particles
par_best_poi = np.copy(par)
# initial optimum fitness values
@@ -264,12 +264,17 @@ class PSOAttack(Attack):
v_particles = self._step_size*(
v_particles + self._c1*ran_1*(best_position - par)) \
+ self._c2*ran_2*(par_best_poi - par)
- par = np.clip(np.clip(par + v_particles,
- par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds,
- par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds),
- self._bounds[0], self._bounds[1])
- if iters > 20 and is_mutation:

+ par += v_particles

+ if iters > 6 and is_mutation:
par = self._mutation_op(par)

+ par = np.clip(np.clip(par,
+ x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds,
+ x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds),
+ self._bounds[0], self._bounds[1])

if self._model_type == 'classification':
confi_adv = self._confidence_cla(par, label_i)
elif self._model_type == 'detection':
@@ -283,8 +288,14 @@ class PSOAttack(Attack):
par_best_poi[k] = par[k]
if fit_value[k] > best_fitness:
best_fitness = fit_value[k]
- best_position = par[k]
+ best_position = par[k].copy()
iters += 1
+ if best_fitness < -2:
+ LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease '
+ 'the value of the initialization parameter c.', best_fitness)
+ if iters < 3 and best_fitness > 100:
+ LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase '
+ 'the value of the initialization parameter c.', best_fitness)
is_mutation = False
if (best_fitness - last_best_fit) < last_best_fit*0.05:
is_mutation = True
@@ -324,7 +335,7 @@ class PSOAttack(Attack):
adv_list.append(best_position)
success_list.append(is_success)
query_times_list.append(q_times)
- del x_copies, cur_noise, par, par_ori, par_best_poi
+ del x_copies, cur_noise, par, par_best_poi
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)
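
A compact numpy sketch of the reordered particle update after this commit: the velocity is applied first, mutation (when triggered) acts on the moved particles, and only afterwards are particles clipped back into the per_bounds neighbourhood of the original copies and the valid pixel range. The constants, the simplified mutation gate, and the fixed iteration count below are illustrative only:

    import numpy as np

    step_size, c1, c2 = 2.0, 2.0, 2.0
    per_bounds, pm = 0.6, 0.5
    bounds = (0.0, 1.0)
    pixel_deep = bounds[1] - bounds[0]

    x_ori = np.random.random((32, 32, 3))
    pop_size = 6
    x_copies = np.repeat(x_ori[np.newaxis, :], pop_size, axis=0)
    par = np.clip(x_copies + np.random.uniform(-0.1, 0.1, x_copies.shape), bounds[0], bounds[1])
    par_best_poi = par.copy()       # best position seen by each particle
    best_position = x_ori.copy()    # global best position
    v_particles = np.zeros_like(par)

    for iters in range(10):
        ran_1 = np.random.random(par.shape)
        ran_2 = np.random.random(par.shape)
        # velocity update with the same shape as in PSOAttack.generate
        v_particles = step_size*(v_particles + c1*ran_1*(best_position - par)) \
                      + c2*ran_2*(par_best_poi - par)
        par = par + v_particles

        if iters > 6:               # mutation now starts after 6 iterations instead of 20
            perturb_noise = (np.random.random(par.shape) - 0.5)*pixel_deep
            par = perturb_noise*(np.random.random(par.shape) < pm) + par

        # clip back to the per_bounds neighbourhood of the originals, then to valid pixels,
        # as in the new code (the old code clipped around par_ori before mutating)
        par = np.clip(np.clip(par,
                              x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*per_bounds,
                              x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*per_bounds),
                      bounds[0], bounds[1])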

tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py (+1, -1)

@@ -175,7 +175,7 @@ def test_genetic_attack_detection_cpu():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
batch_size = 2
- inputs = np.random.random((batch_size, 3, 28, 28))
+ inputs = np.random.random((batch_size, 100, 100, 3))
model = DetectionModel()
attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05,
per_bounds=0.1, step_size=0.25, temp=0.1,


tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py (+1, -1)

@@ -201,7 +201,7 @@ def test_pso_attack_detection_cpu():
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
batch_size = 2
- inputs = np.random.random((batch_size, 3, 28, 28))
+ inputs = np.random.random((batch_size, 100, 100, 3))
model = DetectionModel()
attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5)



tests/ut/python/adv_robustness/attacks/test_jsma.py (+1, -1)

@@ -115,7 +115,7 @@ def test_jsma_attack_gpu():
"""
JSMA-Attack test
"""
- context.set_context(device_target="GPU")
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = Net()
input_shape = (1, 5)
batch_size, classes = input_shape

