@@ -30,9 +30,9 @@ mnist_cfg = edict({
     'keep_checkpoint_max': 10,  # the maximum number of checkpoint files that would be saved
     'device_target': 'Ascend',  # device used
     'data_path': './MNIST_unzip',  # the path of the training and testing data sets
     'dataset_sink_mode': False,  # whether to deliver all training data to the device at once
     'micro_batches': 16,  # the number of small batches split from an original batch
-    'l2_norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
+    'norm_clip': 1.0,  # the clip bound of the gradients of model's training parameters
     'initial_noise_multiplier': 0.2,  # the initial multiplication coefficient of the noise added to training
                                       # parameters' gradients
     'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
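Note: the reason `initial_noise_multiplier` is only "initial" is the mechanism chosen above. With 'AdaGaussian' the noise scale shrinks as training proceeds, while 'Gaussian' keeps it fixed. Below is a minimal, framework-free sketch of that difference; the decay rate and the multiplicative decay form are illustrative assumptions, not MindArmour defaults.

    # Illustrative only: adaptive vs. constant noise multiplier over steps.
    initial_noise_multiplier = 0.2
    noise_decay_rate = 5e-4  # hypothetical decay rate, for illustration

    gaussian = [initial_noise_multiplier] * 5          # stays constant
    ada_gaussian, multiplier = [], initial_noise_multiplier
    for _ in range(5):
        ada_gaussian.append(round(multiplier, 6))
        multiplier *= 1 - noise_decay_rate             # assumed decay form

    print('Gaussian    :', gaussian)
    print('AdaGaussian :', ada_gaussian)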
@@ -108,7 +108,7 @@ if __name__ == "__main__":
     # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian': the noise
     # decays over training with the 'AdaGaussian' mechanism, while it stays constant with the 'Gaussian' mechanism.
     mech = MechanismsFactory().create(cfg.mechanisms,
-                                      norm_bound=cfg.l2_norm_bound,
+                                      norm_bound=cfg.norm_clip,
                                       initial_noise_multiplier=cfg.initial_noise_multiplier)
     net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
     # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget (eps
@@ -116,11 +116,11 @@ if __name__ == "__main__":
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_clip,
                                                per_print_times=10)
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
-                    norm_clip=cfg.l2_norm_bound,
+                    norm_clip=cfg.norm_clip,
                     mech=mech,
                     network=network,
                     loss_fn=net_loss,
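Note: these hunks stop short of the training call that drives them. A minimal sketch of the assumed glue follows, with `ds_train` and `cfg.epoch_size` taken to be the surrounding example script's dataset and epoch count (neither appears in this diff), and the RDP monitor passed as an ordinary MindSpore callback:

    # Assumed tail of the example script; names not shown in the diff are hypothetical.
    model.train(cfg.epoch_size, ds_train,
                callbacks=[rdp_monitor],  # prints the privacy budget during training
                dataset_sink_mode=cfg.dataset_sink_mode)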
@@ -109,7 +109,7 @@ if __name__ == "__main__":
     # decays over training with the 'AdaGaussian' mechanism, while it stays constant with the 'Gaussian' mechanism.
     dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
     dp_opt.set_mechanisms(cfg.mechanisms,
-                          norm_bound=cfg.l2_norm_bound,
+                          norm_bound=cfg.norm_clip,
                           initial_noise_multiplier=cfg.initial_noise_multiplier)
     net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
     # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget (eps
@@ -117,11 +117,11 @@ if __name__ == "__main__":
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_clip,
                                                per_print_times=10)
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
-                    norm_clip=cfg.l2_norm_bound,
+                    norm_clip=cfg.norm_clip,
                     mech=None,
                     network=network,
                     loss_fn=net_loss,
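Note: the two patched scripts above are the two mutually exclusive noise-injection paths; as the validation hunk below shows, DPModel refuses a DPOptimizer combined with a non-None mech. A sketch of the valid and invalid combinations, reusing the names from the hunks above:

    # Path 1: noise added inside the model, plain Momentum optimizer.
    model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.norm_clip,
                    mech=mech, network=network, loss_fn=net_loss, optimizer=net_opt)

    # Path 2: noise added by the DP optimizer, so mech must be None.
    model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.norm_clip,
                    mech=None, network=network, loss_fn=net_loss, optimizer=net_opt)

    # Invalid: a DPOptimizer together with a mechanism raises
    # ValueError('DPOptimizer is not supported while mech is not None').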
@@ -93,7 +93,7 @@ class DPModel(Model):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
         >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
         >>> mech = MechanismsFactory().create('Gaussian',
-        >>>                                   norm_bound=args.l2_norm_bound,
+        >>>                                   norm_bound=args.norm_clip,
         >>>                                   initial_noise_multiplier=args.initial_noise_multiplier)
         >>> model = DPModel(micro_batches=2,
         >>>                 norm_clip=1.0,
@@ -111,8 +111,8 @@ class DPModel(Model): | |||
self._micro_batches = check_int_positive('micro_batches', micro_batches) | |||
else: | |||
self._micro_batches = None | |||
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float) | |||
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip) | |||
norm_clip = check_param_type('norm_clip', norm_clip, float) | |||
self._norm_clip = check_value_positive('norm_clip', norm_clip) | |||
if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: | |||
raise ValueError('DPOptimizer is not supported while mech is not None') | |||
if mech is None: | |||
@@ -180,14 +180,14 @@ class DPModel(Model):
                                                            optimizer,
                                                            scale_update_cell=update_cell,
                                                            micro_batches=self._micro_batches,
-                                                           l2_norm_clip=self._norm_clip,
+                                                           norm_clip=self._norm_clip,
                                                            mech=self._mech).set_train()
             return network
         network = _TrainOneStepCell(network,
                                     optimizer,
                                     loss_scale,
                                     micro_batches=self._micro_batches,
-                                    l2_norm_clip=self._norm_clip,
+                                    norm_clip=self._norm_clip,
                                     mech=self._mech).set_train()
         return network
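Note: which of the two cells _build_train_network builds depends on the loss-scale configuration, which a caller controls only indirectly. A sketch under the assumption that DPModel inherits Model's loss_scale_manager keyword, so passing MindSpore's FixedLossScaleManager routes construction into the loss-scale branch above:

    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    # Assumption: loss_scale_manager is forwarded to Model and picked up here.
    manager = FixedLossScaleManager(loss_scale=1024.0)
    model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.norm_clip,
                    mech=mech, network=network, loss_fn=net_loss,
                    optimizer=net_opt, loss_scale_manager=manager)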
@@ -300,7 +300,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
         optimizer (Cell): Optimizer for updating the weights.
         scale_update_cell (Cell): The loss scaling update logic cell. Default: None.
         micro_batches (int): The number of small batches split from an original batch. Default: None.
-        l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
+        norm_clip (float): The clip bound for gradients; if set to 1, the original data is returned. Default: 1.0.
         mech (Mechanisms): An object that can generate different types of noise. Default: None.

     Inputs:
@@ -316,7 +316,7 @@ class _TrainOneStepWithLossScaleCell(Cell): | |||
- **loss_scale** (Tensor) - Tensor with shape :math:`()`. | |||
""" | |||
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None): | |||
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None): | |||
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | |||
self.network = network | |||
self.network.set_grad() | |||
@@ -358,8 +358,8 @@ class _TrainOneStepWithLossScaleCell(Cell):

         # dp params
         self._micro_batches = micro_batches
-        float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
-        self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
+        norm_clip = check_param_type('norm_clip', norm_clip, float)
+        self._l2_norm = check_value_positive('norm_clip', norm_clip)
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
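Note: the dp params above (_split, _clip_by_global_norm, _mech) implement the clip-then-noise core of DP-SGD: split the batch into micro-batches, clip each micro-batch gradient to the norm bound, add noise scaled by multiplier * norm_clip (matching the monitor's initial_noise_multiplier*cfg.norm_clip earlier), and average. A framework-free NumPy sketch of that computation; the exact order of noising and averaging inside the cell is an assumption:

    import numpy as np

    def dp_average_gradients(grads, micro_batches, norm_clip, noise_multiplier):
        """Sketch of the per-micro-batch clip-then-noise step (NumPy stand-in)."""
        noisy = []
        for micro in np.split(grads, micro_batches, axis=0):
            g = micro.mean(axis=0)                                       # micro-batch gradient
            g = g * min(1.0, norm_clip / (np.linalg.norm(g) + 1e-12))    # clip by global norm
            g = g + np.random.normal(0.0, noise_multiplier * norm_clip, g.shape)
            noisy.append(g)
        return np.mean(noisy, axis=0)                                    # averaged noisy gradient

    toy = np.random.randn(32, 10)  # toy batch: 32 samples, 10 parameters
    print(dp_average_gradients(toy, micro_batches=16, norm_clip=1.0,
                               noise_multiplier=0.2))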
@@ -452,7 +452,7 @@ class _TrainOneStepCell(Cell):
         optimizer (Cell): Optimizer for updating the weights.
         sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
         micro_batches (int): The number of small batches split from an original batch. Default: None.
-        l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
+        norm_clip (float): The clip bound for gradients; if set to 1, the original data is returned. Default: 1.0.
         mech (Mechanisms): An object that can generate different types of noise. Default: None.

     Inputs:
@@ -463,7 +463,7 @@ class _TrainOneStepCell(Cell):
         Tensor, a scalar Tensor with shape :math:`()`.
     """

-    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
+    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, norm_clip=1.0, mech=None):
         super(_TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -484,8 +484,8 @@ class _TrainOneStepCell(Cell):

         # dp params
         self._micro_batches = micro_batches
-        float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
-        self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
+        norm_clip = check_param_type('norm_clip', norm_clip, float)
+        self._l2_norm = check_value_positive('norm_clip', norm_clip)
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
@@ -43,7 +43,7 @@ def dataset_generator(batch_size, batches):
 @pytest.mark.component_mindarmour
 def test_dp_model_pynative_mode():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    l2_norm_bound = 1.0
+    norm_clip = 1.0
     initial_noise_multiplier = 0.01
     network = LeNet5()
     batch_size = 32
@@ -53,11 +53,11 @@ def test_dp_model_pynative_mode():
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
     factory_opt.set_mechanisms('Gaussian',
-                               norm_bound=l2_norm_bound,
+                               norm_bound=norm_clip,
                                initial_noise_multiplier=initial_noise_multiplier)
     net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
     model = DPModel(micro_batches=micro_batches,
-                    norm_clip=l2_norm_bound,
+                    norm_clip=norm_clip,
                     mech=None,
                     network=network,
                     loss_fn=loss,
@@ -75,7 +75,7 @@ def test_dp_model_pynative_mode():
 @pytest.mark.component_mindarmour
 def test_dp_model_with_graph_mode():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    l2_norm_bound = 1.0
+    norm_clip = 1.0
     initial_noise_multiplier = 0.01
     network = LeNet5()
     batch_size = 32
@@ -83,11 +83,11 @@ def test_dp_model_with_graph_mode():
     epochs = 1
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     mech = MechanismsFactory().create('Gaussian',
-                                      norm_bound=l2_norm_bound,
+                                      norm_bound=norm_clip,
                                       initial_noise_multiplier=initial_noise_multiplier)
     net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
     model = DPModel(micro_batches=2,
-                    norm_clip=l2_norm_bound,
+                    norm_clip=norm_clip,
                     mech=mech,
                     network=network,
                     loss_fn=loss,
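Note: the test hunks end before the dataset and training lines. A sketch of the assumed remainder, using the dataset_generator referenced in the hunk header and MindSpore's GeneratorDataset; the column names ('data', 'label') are an assumption not visible in this diff:

    import mindspore.dataset as ds

    # Assumed tail of the test: wrap the generator and run one short epoch.
    ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                                   ['data', 'label'])
    model.train(epochs, ds_train, dataset_sink_mode=False)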