@@ -87,7 +87,7 @@ def test_genetic_attack_on_mnist(): | |||
# attacking | |||
attack = GeneticAttack(model=model, pop_size=6, mutation_rate=0.05, | |||
per_bounds=0.1, step_size=0.25, temp=0.1, | |||
per_bounds=0.4, step_size=0.25, temp=0.1, | |||
sparse=True) | |||
targeted_labels = np.random.randint(0, 10, size=len(true_labels)) | |||
for i, true_l in enumerate(true_labels): | |||
@@ -107,9 +107,7 @@ if __name__ == "__main__": | |||
raise ValueError( | |||
"Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP noise mechanisms, this method is adding noise | |||
# in gradients while training. Initial_noise_multiplier is suggested to be | |||
# greater than 1.0, otherwise the privacy budget would be huge, which means | |||
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian' | |||
# in gradients while training. Mechanisms can be 'Gaussian' | |||
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' | |||
# mechanism while be constant with 'Gaussian' mechanism. | |||
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, | |||
@@ -106,9 +106,7 @@ if __name__ == "__main__": | |||
raise ValueError( | |||
"Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP noise mechanisms, this method is adding noise | |||
# in gradients while training. Initial_noise_multiplier is suggested to be | |||
# greater than 1.0, otherwise the privacy budget would be huge, which means | |||
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian' | |||
# in gradients while training. Mechanisms can be 'Gaussian' | |||
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' | |||
# mechanism while be constant with 'Gaussian' mechanism. | |||
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, | |||
@@ -106,9 +106,7 @@ if __name__ == "__main__": | |||
raise ValueError( | |||
"Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP noise mechanisms, this method is adding noise | |||
# in gradients while training. Initial_noise_multiplier is suggested to be | |||
# greater than 1.0, otherwise the privacy budget would be huge, which means | |||
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian' | |||
# in gradients while training. Mechanisms can be 'Gaussian' | |||
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' | |||
# mechanism while be constant with 'Gaussian' mechanism. | |||
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, | |||
@@ -103,8 +103,7 @@ if __name__ == "__main__": | |||
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: | |||
raise ValueError("Number of micro_batches should divide evenly batch_size") | |||
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training. | |||
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which | |||
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise | |||
# Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise | |||
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism. | |||
dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches) | |||
dp_opt.set_mechanisms(cfg.noise_mechanisms, | |||
@@ -179,19 +179,27 @@ class Attack: | |||
LOGGER.info(TAG, 'Reduction begins...') | |||
model = check_model('model', model, BlackModel) | |||
x_ori = check_numpy_param('x_ori', x_ori) | |||
_, gt_num = self._detection_scores((x_ori,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
best_position = check_numpy_param('best_position', best_position) | |||
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | |||
x_shape = best_position.shape | |||
reduction_iters = 1000 # recover 0.1% each step | |||
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
for _ in range(reduction_iters): | |||
diff = x_ori - best_position | |||
res = 0.5*diff*(np.random.random(x_shape) < 0.001) | |||
best_position += res | |||
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||
q_times += 1 | |||
if correct_num > original_num: | |||
best_position -= res | |||
# pylint: disable=invalid-name | |||
REDUCTION_ITERS = 6 # recover 10% difference each time and recover 60% totally. | |||
for _ in range(REDUCTION_ITERS): | |||
BLOCK_NUM = 30 # divide the image into 30 segments | |||
block_width = best_position.shape[0] // BLOCK_NUM | |||
if block_width > 0: | |||
for i in range(BLOCK_NUM): | |||
diff = x_ori[i*block_width: (i+1)*block_width, :, :]\ | |||
- best_position[i*block_width:(i+1)*block_width, :, :] | |||
if np.max(np.abs(diff)) >= 0.1*(self._bounds[1] - self._bounds[0]): | |||
res = diff*0.1 | |||
best_position[i*block_width: (i+1)*block_width, :, :] += res | |||
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, | |||
gt_labels, model) | |||
q_times += 1 | |||
if correct_num[0] > max(original_num[0], gt_num[0]*self._reserve_ratio): | |||
best_position[i*block_width:(i+1)*block_width, :, :] -= res | |||
return best_position, q_times | |||
@staticmethod | |||
@@ -229,7 +237,7 @@ class Attack: | |||
max_iou_confi = 0 | |||
for j in range(gt_box_num): | |||
iou = calculate_iou(pred_box[:4], gt_box[j][:4]) | |||
if labels[i] == gt_label[j] and iou > iou_thres: | |||
if labels[i] == gt_label[j] and iou > iou_thres and correct_label_flag[j] == 0: | |||
max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) | |||
correct_label_flag[j] = 1 | |||
score += max_iou_confi | |||
@@ -162,7 +162,10 @@ class GeneticAttack(Attack): | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if self._sparse: | |||
label_squ = np.squeeze(labels) | |||
if labels.size > 1: | |||
label_squ = np.squeeze(labels) | |||
else: | |||
label_squ = labels | |||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||
msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ | |||
"and got its shape as {}.".format(labels.shape) | |||
@@ -198,16 +201,17 @@ class GeneticAttack(Attack): | |||
# generate particles | |||
ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||
# initial perturbations | |||
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep, | |||
(0 - self._per_bounds)*pixel_deep, | |||
self._per_bounds*pixel_deep) | |||
cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape) | |||
cur_pop = ori_copies + cur_pert | |||
query_times = 0 | |||
iters = 0 | |||
while iters < self._max_steps: | |||
iters += 1 | |||
cur_pop = np.clip( | |||
ori_copies + cur_pert, self._bounds[0], self._bounds[1]) | |||
cur_pop = np.clip(np.clip(cur_pop, | |||
ori_copies - pixel_deep*self._per_bounds, | |||
ori_copies + pixel_deep*self._per_bounds), | |||
self._bounds[0], self._bounds[1]) | |||
if self._model_type == 'classification': | |||
pop_preds = self._model.predict(cur_pop) | |||
@@ -235,9 +239,19 @@ class GeneticAttack(Attack): | |||
fit_vals = abs( | |||
confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( | |||
(cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) | |||
if np.max(fit_vals) < 0: | |||
self._c /= 2 | |||
if np.max(fit_vals) < -2: | |||
LOGGER.debug(TAG, | |||
'best fitness value is %s, which is too small. We recommend that you decrease ' | |||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||
if iters < 3 and np.max(fit_vals) > 100: | |||
LOGGER.debug(TAG, | |||
'best fitness value is %s, which is too large. We recommend that you increase ' | |||
'the value of the initialization parameter c.', np.max(fit_vals)) | |||
if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): | |||
is_success = True | |||
best_idx = np.argmin(correct_nums_adv) | |||
@@ -252,6 +266,7 @@ class GeneticAttack(Attack): | |||
break | |||
best_fit = max(fit_vals) | |||
if best_fit > self._best_fit: | |||
self._best_fit = best_fit | |||
self._plateau_times = 0 | |||
@@ -263,19 +278,19 @@ class GeneticAttack(Attack): | |||
self._plateau_times = 0 | |||
if self._adaptive: | |||
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) | |||
step_p = max(self._step_size, 0.5*(0.9**self._adap_times)) | |||
step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) | |||
else: | |||
step_noise = self._step_size | |||
step_p = self._mutation_rate | |||
step_temp = self._temp | |||
elite = cur_pert[np.argmax(fit_vals)] | |||
elite = cur_pop[np.argmax(fit_vals)] | |||
select_probs = softmax(fit_vals/step_temp) | |||
select_args = np.arange(self._pop_size) | |||
parents_arg = np.random.choice( | |||
a=select_args, size=2*(self._pop_size - 1), | |||
replace=True, p=select_probs) | |||
parent1 = cur_pert[parents_arg[:self._pop_size - 1]] | |||
parent2 = cur_pert[parents_arg[self._pop_size - 1:]] | |||
parent1 = cur_pop[parents_arg[:self._pop_size - 1]] | |||
parent2 = cur_pop[parents_arg[self._pop_size - 1:]] | |||
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] | |||
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] | |||
parent2_probs = parent2_probs / (parent1_probs + parent2_probs) | |||
@@ -290,11 +305,11 @@ class GeneticAttack(Attack): | |||
mutated_childs = self._mutation( | |||
childs, step_noise=self._per_bounds*step_noise, | |||
prob=step_p) | |||
cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :])) | |||
cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :])) | |||
if not is_success: | |||
LOGGER.debug(TAG, 'fail to find adversarial sample.') | |||
final_adv = elite + x_ori | |||
final_adv = elite | |||
if self._model_type == 'detection': | |||
final_adv, query_times = self._fast_reduction( | |||
x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) | |||
@@ -161,15 +161,12 @@ class PSOAttack(Attack): | |||
Returns: | |||
numpy.ndarray, mutational inputs. | |||
""" | |||
# LOGGER.info(TAG, 'Mutation happens...') | |||
LOGGER.info(TAG, 'Mutation happens...') | |||
pixel_deep = self._bounds[1] - self._bounds[0] | |||
cur_pop = check_numpy_param('cur_pop', cur_pop) | |||
perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep | |||
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop | |||
if self._model_type == 'classification': | |||
mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), | |||
cur_pop + self._per_bounds*np.abs(cur_pop)), | |||
self._bounds[0], self._bounds[1]) | |||
mutated_pop = np.clip(perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop, self._bounds[0], | |||
self._bounds[1]) | |||
return mutated_pop | |||
def generate(self, inputs, labels): | |||
@@ -199,7 +196,10 @@ class PSOAttack(Attack): | |||
inputs, labels = check_pair_numpy_param('inputs', inputs, | |||
'labels', labels) | |||
if self._sparse: | |||
label_squ = np.squeeze(labels) | |||
if labels.size > 1: | |||
label_squ = np.squeeze(labels) | |||
else: | |||
label_squ = labels | |||
if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: | |||
msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ | |||
"got its shape as {}.".format(labels.shape) | |||
@@ -242,12 +242,12 @@ class PSOAttack(Attack): | |||
# initial global optimum position | |||
best_position = x_ori | |||
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) | |||
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size, | |||
cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep, | |||
(0 - self._per_bounds)*(np.abs(x_copies) + 0.1), | |||
self._per_bounds*(np.abs(x_copies) + 0.1)) | |||
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) | |||
# initial advs | |||
par_ori = np.copy(par) | |||
par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) | |||
# initial optimum positions for particles | |||
par_best_poi = np.copy(par) | |||
# initial optimum fitness values | |||
@@ -264,12 +264,17 @@ class PSOAttack(Attack): | |||
v_particles = self._step_size*( | |||
v_particles + self._c1*ran_1*(best_position - par)) \ | |||
+ self._c2*ran_2*(par_best_poi - par) | |||
par = np.clip(np.clip(par + v_particles, | |||
par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds, | |||
par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds), | |||
self._bounds[0], self._bounds[1]) | |||
if iters > 20 and is_mutation: | |||
par += v_particles | |||
if iters > 6 and is_mutation: | |||
par = self._mutation_op(par) | |||
par = np.clip(np.clip(par, | |||
x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds, | |||
x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), | |||
self._bounds[0], self._bounds[1]) | |||
if self._model_type == 'classification': | |||
confi_adv = self._confidence_cla(par, label_i) | |||
elif self._model_type == 'detection': | |||
@@ -283,8 +288,14 @@ class PSOAttack(Attack): | |||
par_best_poi[k] = par[k] | |||
if fit_value[k] > best_fitness: | |||
best_fitness = fit_value[k] | |||
best_position = par[k] | |||
best_position = par[k].copy() | |||
iters += 1 | |||
if best_fitness < -2: | |||
LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' | |||
'the value of the initialization parameter c.', best_fitness) | |||
if iters < 3 and best_fitness > 100: | |||
LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase ' | |||
'the value of the initialization parameter c.', best_fitness) | |||
is_mutation = False | |||
if (best_fitness - last_best_fit) < last_best_fit*0.05: | |||
is_mutation = True | |||
@@ -324,7 +335,7 @@ class PSOAttack(Attack): | |||
adv_list.append(best_position) | |||
success_list.append(is_success) | |||
query_times_list.append(q_times) | |||
del x_copies, cur_noise, par, par_ori, par_best_poi | |||
del x_copies, cur_noise, par, par_best_poi | |||
return np.asarray(success_list), \ | |||
np.asarray(adv_list), \ | |||
np.asarray(query_times_list) |
@@ -175,7 +175,7 @@ def test_genetic_attack_detection_cpu(): | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
batch_size = 2 | |||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||
inputs = np.random.random((batch_size, 100, 100, 3)) | |||
model = DetectionModel() | |||
attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05, | |||
per_bounds=0.1, step_size=0.25, temp=0.1, | |||
@@ -201,7 +201,7 @@ def test_pso_attack_detection_cpu(): | |||
""" | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
batch_size = 2 | |||
inputs = np.random.random((batch_size, 3, 28, 28)) | |||
inputs = np.random.random((batch_size, 100, 100, 3)) | |||
model = DetectionModel() | |||
attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5) | |||
@@ -115,7 +115,7 @@ def test_jsma_attack_gpu(): | |||
""" | |||
JSMA-Attack test | |||
""" | |||
context.set_context(device_target="GPU") | |||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
net = Net() | |||
input_shape = (1, 5) | |||
batch_size, classes = input_shape | |||