diff --git a/docs/api/api_python/mindarmour.reliability.rst b/docs/api/api_python/mindarmour.reliability.rst
index f15aa56..4572f6f 100644
--- a/docs/api/api_python/mindarmour.reliability.rst
+++ b/docs/api/api_python/mindarmour.reliability.rst
@@ -108,6 +108,8 @@ Reliability methods of MindArmour.
 
     .. py:method:: get_optimal_threshold(label, ds_eval)
 
+        Get the optimal threshold. Try to find an optimal threshold value to detect OOD samples. The optimal threshold is calculated by the labeled dataset `ds_eval`.
+
         Parameters:
             - **label** (numpy.ndarray) - The label that distinguishes whether an image is in-distribution or out-of-distribution.
             - **ds_eval** (numpy.ndarray) - The testing dataset that helps to find the threshold.
diff --git a/examples/privacy/diff_privacy/lenet5_config.py b/examples/privacy/diff_privacy/lenet5_config.py
index e0e2f97..dfd3983 100644
--- a/examples/privacy/diff_privacy/lenet5_config.py
+++ b/examples/privacy/diff_privacy/lenet5_config.py
@@ -33,7 +33,7 @@ mnist_cfg = edict({
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
     'micro_batches': 32,  # the number of small batches split from an original batch
     'norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 1.0,  # the initial multiplication coefficient of the noise added to training
                                        # parameters' gradients
     'noise_mechanisms': 'Gaussian',  # the method of adding noise in gradients while training
     'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
@@ -41,5 +41,6 @@ mnist_cfg = edict({
     'clip_learning_rate': 0.001,  # Learning rate of update norm clip.
     'target_unclipped_quantile': 0.9,  # Target quantile of norm clip.
     'fraction_stddev': 0.01,  # The stddev of Gaussian normal which used in empirical_fraction.
-    'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
+    'optimizer': 'Momentum',  # the base optimizer used for Differential privacy training
+    'target_delta': 1e-5  # the target delta budget for DP training
 })
diff --git a/examples/privacy/diff_privacy/lenet5_dp.py b/examples/privacy/diff_privacy/lenet5_dp.py
index c112699..49b686e 100644
--- a/examples/privacy/diff_privacy/lenet5_dp.py
+++ b/examples/privacy/diff_privacy/lenet5_dp.py
@@ -134,6 +134,7 @@ if __name__ == "__main__":
                                               batch_size=cfg.batch_size,
                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
                                               per_print_times=234,
+                                              target_delta=cfg.target_delta,
                                               noise_decay_mode=None)
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
diff --git a/mindarmour/privacy/diff_privacy/train/model.py b/mindarmour/privacy/diff_privacy/train/model.py
index 7af3480..abf5500 100644
--- a/mindarmour/privacy/diff_privacy/train/model.py
+++ b/mindarmour/privacy/diff_privacy/train/model.py
@@ -54,7 +54,6 @@ from ..mechanisms.mechanisms import _MechanismsParamsUpdater
 
 LOGGER = LogUtil.get_instance()
 TAG = 'DP model'
-GRADIENT_CLIP_TYPE = 1
 
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 _reciprocal = P.Reciprocal()
@@ -76,8 +75,8 @@ class DPModel(Model):
     Args:
         micro_batches (int): The number of small batches split from an
            original batch. Default: 2.
-        norm_bound (float): Use to clip the bound, if set 1, will return the
-            original data. Default: 1.0.
+        norm_bound (float): The norm bound that is used to clip the gradient of
+            each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type of
            noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
@@ -275,9 +274,10 @@ class _ClipGradients(nn.Cell):
    Clip gradients.
 
    Inputs:
-        grads (tuple[Tensor]): Gradients.
-        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
-        clip_value (float): Specifies how much to clip.
+        grads (tuple[Tensor]): Gradients to clip.
+        clip_norm (float): The l2-norm bound used to clip the gradients.
+        cur_norm (float): The l2-norm of grads. If None, the norm will be
+            calculated in this function. Default: None.
 
    Outputs:
        tuple[Tensor], clipped gradients.
@@ -285,24 +285,29 @@
 
    def __init__(self):
        super(_ClipGradients, self).__init__()
-        self.clip_by_norm = nn.ClipByNorm()
-        self.dtype = P.DType()
+        self._add = P.Add()
+        self._reduce_sum = P.ReduceSum()
+        self._square_all = P.Square()
+        self._sqrt = P.Sqrt()
 
-    def construct(self, grads, clip_type, clip_value):
+    def construct(self, grads, clip_norm, cur_norm=None):
        """
        construct a compute flow.
        """
-        if clip_type not in (0, 1):
+        if cur_norm is None:
+            # calculate current l2-norm of grads
+            square_sum = Tensor(0, mstype.float32)
+            for grad in grads:
+                square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+            cur_norm = self._sqrt(square_sum)
+
+        if cur_norm <= clip_norm:
            return grads
 
        new_grads = ()
        for grad in grads:
-            if clip_type == 0:
-                norm = C.clip_by_value(grad, -clip_value, clip_value)
-            else:
-                norm = self.clip_by_norm(grad, clip_value)
-            new_grads = new_grads + (norm,)
-
+            clipped_grad = grad * (clip_norm / cur_norm)
+            new_grads = new_grads + (clipped_grad,)
        return new_grads
 
 
@@ -339,8 +344,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
            Default: None.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
-        norm_bound (Tensor): Use to clip the bound, if set 1, will return the
-            original data. Default: 1.0.
+        norm_bound (Tensor): The norm bound that is used to clip the gradient of
+            each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type of
            noise. Default: None.
@@ -466,8 +471,8 @@
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))
-        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
-                                                self._norm_bound)
+        record_grad = self._clip_by_global_norm(record_grad,
+                                                self._norm_bound, norm_grad)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
@@ -488,8 +493,7 @@
                                           mstype.float32))
 
            record_grad = self._clip_by_global_norm(record_grad,
-                                                    GRADIENT_CLIP_TYPE,
-                                                    self._norm_bound)
+                                                    self._norm_bound, norm_grad)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
            loss = P.Div()(total_loss, self._micro_float)
@@ -560,8 +564,8 @@ class _TrainOneStepCell(Cell):
            propagation. Default value is 1.0.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
-        norm_bound (Tensor): Use to clip the bound, if set 1, will return the
-            original data. Default: 1.0.
+        norm_bound (Tensor): The norm bound that is used to clip the gradient of
+            each sample. Default: 1.0.
        noise_mech (Mechanisms): The object can generate the different type of
            noise. Default: None.
        clip_mech (Mechanisms): The object is used to update the adaptive clip.
@@ -644,20 +648,22 @@
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0],
                                                       record_labels[0], sens)
-        beta = self._zero
+
+        # calculate norm_grad
+        square_sum = self._zero
+        for grad in record_grad:
+            square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+        norm_grad = self._sqrt(square_sum)
 
+        # calculate beta
+        beta = self._zero
        if self._clip_mech is not None:
-            square_sum = self._zero
-            for grad in record_grad:
-                square_sum = self._add(square_sum,
-                                       self._reduce_sum(self._square_all(grad)))
-            norm_grad = self._sqrt(square_sum)
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))
 
-        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
-                                                self._norm_bound)
+        record_grad = self._clip_by_global_norm(record_grad,
+                                                self._norm_bound, norm_grad)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
@@ -666,20 +672,22 @@
            record_grad = self.grad(self.network,
                                    weights)(record_datas[i],
                                             record_labels[i],
                                             sens)
+
+            # calculate norm_grad
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum,
+                                       self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            # calculate beta
            if self._clip_mech is not None:
-                square_sum = self._zero
-                for grad in record_grad:
-                    square_sum = self._add(square_sum,
-                                           self._reduce_sum(self._square_all(grad)))
-                norm_grad = self._sqrt(square_sum)
                beta = self._add(beta,
                                 self._cast(self._less(norm_grad, self._norm_bound),
                                            mstype.float32))
 
            record_grad = self._clip_by_global_norm(record_grad,
-                                                    GRADIENT_CLIP_TYPE,
-                                                    self._norm_bound)
+                                                    self._norm_bound, norm_grad)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
            loss = self._div(total_loss, self._micro_float)
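Note on the reworked _ClipGradients above: construct() now performs plain global-norm clipping. It computes (or reuses) the l2-norm of the whole gradient tuple and, only when that norm exceeds clip_norm, rescales every gradient by clip_norm / cur_norm; the training cells pass the norm_grad they already computed as cur_norm so the norm is not evaluated twice. A minimal NumPy sketch of the same rule, for illustration only (clip_by_global_norm below is a hypothetical helper, not the MindSpore cell in the patch):

import numpy as np

def clip_by_global_norm(grads, clip_norm, cur_norm=None):
    # Global l2-norm over all gradient tensors, as construct() does when cur_norm is None.
    if cur_norm is None:
        cur_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    # Gradients already inside the norm ball are returned untouched.
    if cur_norm <= clip_norm:
        return grads
    # Otherwise rescale so the joint norm equals clip_norm.
    return [g * (clip_norm / cur_norm) for g in grads]

# The joint norm of these two gradients is 5.0; clipping to 1.0 rescales both by 0.2.
grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]
clipped = clip_by_global_norm(grads, clip_norm=1.0)
print(np.sqrt(sum(np.sum(np.square(g)) for g in clipped)))  # -> ~1.0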
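The hoisted norm computation in _TrainOneStepCell (and the loss-scale variant) also feeds beta, which counts how many micro-batch gradients fall below norm_bound; the adaptive clip mechanism compares that fraction against target_unclipped_quantile (0.9 in lenet5_config.py) when it updates the bound. A rough NumPy illustration of what beta measures, assuming one gradient tuple per micro-batch (names here are illustrative, not MindArmour APIs):

import numpy as np

def unclipped_fraction(micro_batch_grads, norm_bound):
    # Count the micro-batch gradients whose global l2-norm is below the clip bound,
    # mirroring the beta accumulation in the patch above.
    below = 0.0
    for grads in micro_batch_grads:
        norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
        below += float(norm < norm_bound)
    return below / len(micro_batch_grads)

# With norm_bound = 1.0, only the second micro-batch is already below the bound.
batches = [[np.array([3.0, 4.0])], [np.array([0.3, 0.4])]]
print(unclipped_fraction(batches, norm_bound=1.0))  # -> 0.5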