diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py index 4b947bc..1ba1e92 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py @@ -87,7 +87,7 @@ def test_genetic_attack_on_mnist(): # attacking attack = GeneticAttack(model=model, pop_size=6, mutation_rate=0.05, - per_bounds=0.1, step_size=0.25, temp=0.1, + per_bounds=0.4, step_size=0.25, temp=0.1, sparse=True) targeted_labels = np.random.randint(0, 10, size=len(true_labels)) for i, true_l in enumerate(true_labels): diff --git a/examples/privacy/diff_privacy/lenet5_dp.py b/examples/privacy/diff_privacy/lenet5_dp.py index ad0e243..d6f74e0 100644 --- a/examples/privacy/diff_privacy/lenet5_dp.py +++ b/examples/privacy/diff_privacy/lenet5_dp.py @@ -107,9 +107,7 @@ if __name__ == "__main__": raise ValueError( "Number of micro_batches should divide evenly batch_size") # Create a factory class of DP noise mechanisms, this method is adding noise - # in gradients while training. Initial_noise_multiplier is suggested to be - # greater than 1.0, otherwise the privacy budget would be huge, which means - # that the privacy protection effect is weak. Mechanisms can be 'Gaussian' + # in gradients while training. Mechanisms can be 'Gaussian' # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' # mechanism while be constant with 'Gaussian' mechanism. noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, diff --git a/examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py b/examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py index e23b6ed..918d824 100644 --- a/examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py +++ b/examples/privacy/diff_privacy/lenet5_dp_ada_gaussian.py @@ -106,9 +106,7 @@ if __name__ == "__main__": raise ValueError( "Number of micro_batches should divide evenly batch_size") # Create a factory class of DP noise mechanisms, this method is adding noise - # in gradients while training. Initial_noise_multiplier is suggested to be - # greater than 1.0, otherwise the privacy budget would be huge, which means - # that the privacy protection effect is weak. Mechanisms can be 'Gaussian' + # in gradients while training. Mechanisms can be 'Gaussian' # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' # mechanism while be constant with 'Gaussian' mechanism. noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, diff --git a/examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py b/examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py index c223eb8..0c36419 100644 --- a/examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py +++ b/examples/privacy/diff_privacy/lenet5_dp_ada_sgd_graph.py @@ -106,9 +106,7 @@ if __name__ == "__main__": raise ValueError( "Number of micro_batches should divide evenly batch_size") # Create a factory class of DP noise mechanisms, this method is adding noise - # in gradients while training. Initial_noise_multiplier is suggested to be - # greater than 1.0, otherwise the privacy budget would be huge, which means - # that the privacy protection effect is weak. Mechanisms can be 'Gaussian' + # in gradients while training. Mechanisms can be 'Gaussian' # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' # mechanism while be constant with 'Gaussian' mechanism. noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, diff --git a/examples/privacy/diff_privacy/lenet5_dp_optimizer.py b/examples/privacy/diff_privacy/lenet5_dp_optimizer.py index 5bee8a5..05f7205 100644 --- a/examples/privacy/diff_privacy/lenet5_dp_optimizer.py +++ b/examples/privacy/diff_privacy/lenet5_dp_optimizer.py @@ -103,8 +103,7 @@ if __name__ == "__main__": if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: raise ValueError("Number of micro_batches should divide evenly batch_size") # Create a factory class of DP mechanisms, this method is adding noise in gradients while training. - # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which - # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise + # Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise # would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism. dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches) dp_opt.set_mechanisms(cfg.noise_mechanisms, diff --git a/mindarmour/adv_robustness/attacks/attack.py b/mindarmour/adv_robustness/attacks/attack.py index 1c97a2a..bf543bd 100644 --- a/mindarmour/adv_robustness/attacks/attack.py +++ b/mindarmour/adv_robustness/attacks/attack.py @@ -179,19 +179,27 @@ class Attack: LOGGER.info(TAG, 'Reduction begins...') model = check_model('model', model, BlackModel) x_ori = check_numpy_param('x_ori', x_ori) + _, gt_num = self._detection_scores((x_ori,) + auxiliary_inputs, gt_boxes, gt_labels, model) best_position = check_numpy_param('best_position', best_position) x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) - x_shape = best_position.shape - reduction_iters = 1000 # recover 0.1% each step _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) - for _ in range(reduction_iters): - diff = x_ori - best_position - res = 0.5*diff*(np.random.random(x_shape) < 0.001) - best_position += res - _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) - q_times += 1 - if correct_num > original_num: - best_position -= res + # pylint: disable=invalid-name + REDUCTION_ITERS = 6 # recover 10% difference each time and recover 60% totally. + for _ in range(REDUCTION_ITERS): + BLOCK_NUM = 30 # divide the image into 30 segments + block_width = best_position.shape[0] // BLOCK_NUM + if block_width > 0: + for i in range(BLOCK_NUM): + diff = x_ori[i*block_width: (i+1)*block_width, :, :]\ + - best_position[i*block_width:(i+1)*block_width, :, :] + if np.max(np.abs(diff)) >= 0.1*(self._bounds[1] - self._bounds[0]): + res = diff*0.1 + best_position[i*block_width: (i+1)*block_width, :, :] += res + _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, + gt_labels, model) + q_times += 1 + if correct_num[0] > max(original_num[0], gt_num[0]*self._reserve_ratio): + best_position[i*block_width:(i+1)*block_width, :, :] -= res return best_position, q_times @staticmethod @@ -229,7 +237,7 @@ class Attack: max_iou_confi = 0 for j in range(gt_box_num): iou = calculate_iou(pred_box[:4], gt_box[j][:4]) - if labels[i] == gt_label[j] and iou > iou_thres: + if labels[i] == gt_label[j] and iou > iou_thres and correct_label_flag[j] == 0: max_iou_confi = max(max_iou_confi, pred_box[-1] + iou) correct_label_flag[j] = 1 score += max_iou_confi diff --git a/mindarmour/adv_robustness/attacks/black/genetic_attack.py b/mindarmour/adv_robustness/attacks/black/genetic_attack.py index dbacc40..7a52379 100644 --- a/mindarmour/adv_robustness/attacks/black/genetic_attack.py +++ b/mindarmour/adv_robustness/attacks/black/genetic_attack.py @@ -162,7 +162,10 @@ class GeneticAttack(Attack): inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) if self._sparse: - label_squ = np.squeeze(labels) + if labels.size > 1: + label_squ = np.squeeze(labels) + else: + label_squ = labels if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels is not sparse style " \ "and got its shape as {}.".format(labels.shape) @@ -198,16 +201,17 @@ class GeneticAttack(Attack): # generate particles ori_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) # initial perturbations - cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size*pixel_deep, - (0 - self._per_bounds)*pixel_deep, - self._per_bounds*pixel_deep) - + cur_pert = np.random.uniform(self._bounds[0], self._bounds[1], ori_copies.shape) + cur_pop = ori_copies + cur_pert query_times = 0 iters = 0 + while iters < self._max_steps: iters += 1 - cur_pop = np.clip( - ori_copies + cur_pert, self._bounds[0], self._bounds[1]) + cur_pop = np.clip(np.clip(cur_pop, + ori_copies - pixel_deep*self._per_bounds, + ori_copies + pixel_deep*self._per_bounds), + self._bounds[0], self._bounds[1]) if self._model_type == 'classification': pop_preds = self._model.predict(cur_pop) @@ -235,9 +239,19 @@ class GeneticAttack(Attack): fit_vals = abs( confi_ori - confi_adv) - self._c / self._pop_size * np.linalg.norm( (cur_pop - x_ori).reshape(cur_pop.shape[0], -1), axis=1) + if np.max(fit_vals) < 0: self._c /= 2 + if np.max(fit_vals) < -2: + LOGGER.debug(TAG, + 'best fitness value is %s, which is too small. We recommend that you decrease ' + 'the value of the initialization parameter c.', np.max(fit_vals)) + if iters < 3 and np.max(fit_vals) > 100: + LOGGER.debug(TAG, + 'best fitness value is %s, which is too large. We recommend that you increase ' + 'the value of the initialization parameter c.', np.max(fit_vals)) + if np.min(correct_nums_adv) <= int(gt_object_num*self._reserve_ratio): is_success = True best_idx = np.argmin(correct_nums_adv) @@ -252,6 +266,7 @@ class GeneticAttack(Attack): break best_fit = max(fit_vals) + if best_fit > self._best_fit: self._best_fit = best_fit self._plateau_times = 0 @@ -263,19 +278,19 @@ class GeneticAttack(Attack): self._plateau_times = 0 if self._adaptive: step_noise = max(self._step_size, 0.4*(0.9**self._adap_times)) - step_p = max(self._step_size, 0.5*(0.9**self._adap_times)) + step_p = max(self._mutation_rate, 0.5*(0.9**self._adap_times)) else: step_noise = self._step_size step_p = self._mutation_rate step_temp = self._temp - elite = cur_pert[np.argmax(fit_vals)] + elite = cur_pop[np.argmax(fit_vals)] select_probs = softmax(fit_vals/step_temp) select_args = np.arange(self._pop_size) parents_arg = np.random.choice( a=select_args, size=2*(self._pop_size - 1), replace=True, p=select_probs) - parent1 = cur_pert[parents_arg[:self._pop_size - 1]] - parent2 = cur_pert[parents_arg[self._pop_size - 1:]] + parent1 = cur_pop[parents_arg[:self._pop_size - 1]] + parent2 = cur_pop[parents_arg[self._pop_size - 1:]] parent1_probs = select_probs[parents_arg[:self._pop_size - 1]] parent2_probs = select_probs[parents_arg[self._pop_size - 1:]] parent2_probs = parent2_probs / (parent1_probs + parent2_probs) @@ -290,11 +305,11 @@ class GeneticAttack(Attack): mutated_childs = self._mutation( childs, step_noise=self._per_bounds*step_noise, prob=step_p) - cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :])) + cur_pop = np.concatenate((mutated_childs, elite[np.newaxis, :])) if not is_success: LOGGER.debug(TAG, 'fail to find adversarial sample.') - final_adv = elite + x_ori + final_adv = elite if self._model_type == 'detection': final_adv, query_times = self._fast_reduction( x_ori, final_adv, query_times, auxiliary_input_i, gt_boxes_i, gt_labels_i, model=self._model) diff --git a/mindarmour/adv_robustness/attacks/black/pso_attack.py b/mindarmour/adv_robustness/attacks/black/pso_attack.py index 2811dc3..25bcf54 100644 --- a/mindarmour/adv_robustness/attacks/black/pso_attack.py +++ b/mindarmour/adv_robustness/attacks/black/pso_attack.py @@ -161,15 +161,12 @@ class PSOAttack(Attack): Returns: numpy.ndarray, mutational inputs. """ - # LOGGER.info(TAG, 'Mutation happens...') + LOGGER.info(TAG, 'Mutation happens...') pixel_deep = self._bounds[1] - self._bounds[0] cur_pop = check_numpy_param('cur_pop', cur_pop) perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep - mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop - if self._model_type == 'classification': - mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop), - cur_pop + self._per_bounds*np.abs(cur_pop)), - self._bounds[0], self._bounds[1]) + mutated_pop = np.clip(perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop, self._bounds[0], + self._bounds[1]) return mutated_pop def generate(self, inputs, labels): @@ -199,7 +196,10 @@ class PSOAttack(Attack): inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels) if self._sparse: - label_squ = np.squeeze(labels) + if labels.size > 1: + label_squ = np.squeeze(labels) + else: + label_squ = labels if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]: msg = "The parameter 'sparse' of PSOAttack is True, but the input labels is not sparse style and " \ "got its shape as {}.".format(labels.shape) @@ -242,12 +242,12 @@ class PSOAttack(Attack): # initial global optimum position best_position = x_ori x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0) - cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)*pixel_deep*self._step_size, + cur_noise = np.clip(np.random.random(x_copies.shape)*pixel_deep, (0 - self._per_bounds)*(np.abs(x_copies) + 0.1), self._per_bounds*(np.abs(x_copies) + 0.1)) - par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) + # initial advs - par_ori = np.copy(par) + par = np.clip(x_copies + cur_noise, self._bounds[0], self._bounds[1]) # initial optimum positions for particles par_best_poi = np.copy(par) # initial optimum fitness values @@ -264,12 +264,17 @@ class PSOAttack(Attack): v_particles = self._step_size*( v_particles + self._c1*ran_1*(best_position - par)) \ + self._c2*ran_2*(par_best_poi - par) - par = np.clip(np.clip(par + v_particles, - par_ori - (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds, - par_ori + (np.abs(par_ori) + 0.1*pixel_deep)*self._per_bounds), - self._bounds[0], self._bounds[1]) - if iters > 20 and is_mutation: + + par += v_particles + + if iters > 6 and is_mutation: par = self._mutation_op(par) + + par = np.clip(np.clip(par, + x_copies - (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds, + x_copies + (np.abs(x_copies) + 0.1*pixel_deep)*self._per_bounds), + self._bounds[0], self._bounds[1]) + if self._model_type == 'classification': confi_adv = self._confidence_cla(par, label_i) elif self._model_type == 'detection': @@ -283,8 +288,14 @@ class PSOAttack(Attack): par_best_poi[k] = par[k] if fit_value[k] > best_fitness: best_fitness = fit_value[k] - best_position = par[k] + best_position = par[k].copy() iters += 1 + if best_fitness < -2: + LOGGER.debug(TAG, 'best fitness value is %s, which is too small. We recommend that you decrease ' + 'the value of the initialization parameter c.', best_fitness) + if iters < 3 and best_fitness > 100: + LOGGER.debug(TAG, 'best fitness value is %s, which is too large. We recommend that you increase ' + 'the value of the initialization parameter c.', best_fitness) is_mutation = False if (best_fitness - last_best_fit) < last_best_fit*0.05: is_mutation = True @@ -324,7 +335,7 @@ class PSOAttack(Attack): adv_list.append(best_position) success_list.append(is_success) query_times_list.append(q_times) - del x_copies, cur_noise, par, par_ori, par_best_poi + del x_copies, cur_noise, par, par_best_poi return np.asarray(success_list), \ np.asarray(adv_list), \ np.asarray(query_times_list) diff --git a/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py index 76f860d..bcc3552 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py @@ -175,7 +175,7 @@ def test_genetic_attack_detection_cpu(): """ context.set_context(mode=context.GRAPH_MODE, device_target="CPU") batch_size = 2 - inputs = np.random.random((batch_size, 3, 28, 28)) + inputs = np.random.random((batch_size, 100, 100, 3)) model = DetectionModel() attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05, per_bounds=0.1, step_size=0.25, temp=0.1, diff --git a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py index d3da309..7a5b5e0 100644 --- a/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py +++ b/tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py @@ -201,7 +201,7 @@ def test_pso_attack_detection_cpu(): """ context.set_context(mode=context.GRAPH_MODE, device_target="CPU") batch_size = 2 - inputs = np.random.random((batch_size, 3, 28, 28)) + inputs = np.random.random((batch_size, 100, 100, 3)) model = DetectionModel() attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5) diff --git a/tests/ut/python/adv_robustness/attacks/test_jsma.py b/tests/ut/python/adv_robustness/attacks/test_jsma.py index 5457a98..98678dd 100644 --- a/tests/ut/python/adv_robustness/attacks/test_jsma.py +++ b/tests/ut/python/adv_robustness/attacks/test_jsma.py @@ -115,7 +115,7 @@ def test_jsma_attack_gpu(): """ JSMA-Attack test """ - context.set_context(device_target="GPU") + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") net = Net() input_shape = (1, 5) batch_size, classes = input_shape