Compare commits


8 Commits
master ... r0.5

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| mindspore-ci-bot | 456a8a03a9 | !47 update release | 5 years ago |
| zhenghuanhuan | 6029a36719 | update release. | 5 years ago |
| mindspore-ci-bot | 4ce64a6473 | !45 fix issue | 5 years ago |
| mindspore-ci-bot | 92165efc37 | !42 Add example of MechanismsFactory and add example of mech=None in dpmodel | 5 years ago |
| ZhidanLiu | aaa9f89f7d | add example of MechanismsFactory and add example of mech=None in dpmodel | 5 years ago |
| mindspore-ci-bot | 4f2b3cf4c2 | !40 Solve issue: [CT][MA][DP] TGaussian default parameters in graph mode is unqualified. https://gitee.com/mindspore/dashboard/issues?id=I1LMJD | 5 years ago |
| ZhidanLiu | 9649ab3d8d | update default value | 5 years ago |
| mindspore-ci-bot | 2905c9e014 | !41 fix issue | 5 years ago |
5 changed files with 90 additions and 40 deletions
Unified View
1. RELEASE.md (+23, -0)
2. example/mnist_demo/lenet5_config.py (+1, -1)
3. example/mnist_demo/lenet5_dp.py (+1, -1)
4. mindarmour/diff_privacy/mechanisms/mechanisms.py (+46, -10)
5. mindarmour/diff_privacy/train/model.py (+19, -28)

RELEASE.md (+23, -0)

@@ -1,3 +1,26 @@
+# Release 0.5.0-beta
+
+## Major Features and Improvements
+
+### Differential privacy model training
+
+* Optimizers with differential privacy
+
+* Differential privacy model training now supports both PyNative mode and graph mode.
+
+* Graph mode is recommended for better performance.
+
+## Bugfixes
+
+## Contributors
+
+Thanks go to these wonderful people:
+
+Liu Liu, Huanhuan Zheng, Xiulang Jin, Zhidan Liu.
+
+Contributions of any kind are welcome!
+
+
 # Release 0.3.0-alpha
 
 ## Major Features and Improvements

example/mnist_demo/lenet5_config.py (+1, -1)

@@ -33,7 +33,7 @@ mnist_cfg = edict({
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
     'micro_batches': 16,  # the number of small batches split from an original batch
     'norm_clip': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 0.2,  # the initial multiplication coefficient of the noise added to training
+    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to training
                                       # parameters' gradients
     'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
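
To read this change in units that matter: the mechanisms.py docstrings later in this diff define `initial_noise_multiplier` as the ratio of the noise standard deviation to the clipping bound, so with `norm_clip` left at 1.0 the effective stddev of the per-step Gaussian noise rises from 0.2 to 1.5. A quick sanity check of that product:

```python
# Effective noise stddev = initial_noise_multiplier * norm_clip
# (per the parameter docs in mindarmour/diff_privacy/mechanisms/mechanisms.py).
norm_clip = 1.0
for initial_noise_multiplier in (0.2, 1.5):  # old vs. new config value
    print("noise stddev =", initial_noise_multiplier * norm_clip)
```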


example/mnist_demo/lenet5_dp.py (+1, -1)

@@ -87,7 +87,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 
 
 if __name__ == "__main__":
     # This configure can run both in pynative mode and graph mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
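
The one-line change above is the entire switch between the two execution modes mentioned in the release notes. A minimal sketch of the pattern, assuming a MindSpore install of this vintage; "CPU" stands in for the demo's actual `cfg.device_target`:

```python
from mindspore import context

# Graph mode compiles the whole network up front and is the mode the release
# notes recommend for performance; PyNative mode runs operators eagerly and
# is easier to step through when debugging.
debug = False
mode = context.PYNATIVE_MODE if debug else context.GRAPH_MODE
context.set_context(mode=mode, device_target="CPU")
```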


mindarmour/diff_privacy/mechanisms/mechanisms.py (+46, -10)

@@ -38,8 +38,8 @@ class MechanismsFactory:
     """
     Args:
         policy(str): Noise generated strategy, could be 'Gaussian' or
-            'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism while
-            be constant with 'Gaussian' mechanism. Default: 'AdaGaussian'.
+            'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism
+            while be constant with 'Gaussian' mechanism.
         args(Union[float, str]): Parameters used for creating noise
             mechanisms.
         kwargs(Union[float, str]): Parameters used for creating noise
@@ -47,8 +47,44 @@ class MechanismsFactory:
 
     Raises:
         NameError: `policy` must be in ['Gaussian', 'AdaGaussian'].
 
     Returns:
         Mechanisms, class of noise generated Mechanism.
 
+    Examples:
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>> norm_clip = 1.0
+        >>> initial_noise_multiplier = 1.5
+        >>> net = Net()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+        >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
+        >>> mech = MechanismsFactory().create('Gaussian',
+        >>>                                   norm_bound=norm_clip,
+        >>>                                   initial_noise_multiplier=initial_noise_multiplier)
+        >>> model = DPModel(micro_batches=2,
+        >>>                 norm_clip=1.0,
+        >>>                 mech=mech,
+        >>>                 network=net,
+        >>>                 loss_fn=loss,
+        >>>                 optimizer=net_opt,
+        >>>                 metrics=None)
+        >>> dataset = get_dataset()
+        >>> model.train(2, dataset)
     """
     if policy == 'Gaussian':
         return GaussianRandom(*args, **kwargs)
@@ -75,7 +111,7 @@ class GaussianRandom(Mechanisms):
 
     Args:
         norm_bound(float): Clipping bound for the l2 norm of the gradients.
-            Default: 1.0.
+            Default: 0.5.
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 1.5.
 
@@ -87,14 +123,14 @@ class GaussianRandom(Mechanisms):
 
     Examples:
         >>> gradients = Tensor([0.2, 0.9], mstype.float32)
-        >>> norm_bound = 1.0
-        >>> initial_noise_multiplier = 0.1
+        >>> norm_bound = 0.5
+        >>> initial_noise_multiplier = 1.5
         >>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
         >>> res = net(gradients)
         >>> print(res)
     """
 
-    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0, seed=0):
+    def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, mean=0.0, seed=0):
         super(GaussianRandom, self).__init__()
         self._norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(norm_bound, mstype.float32)
@@ -129,10 +165,10 @@ class AdaGaussianRandom(Mechanisms):
 
     Args:
         norm_bound(float): Clipping bound for the l2 norm of the gradients.
-            Default: 1.5.
+            Default: 1.0.
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
-            calculate privacy spent. Default: 5.0.
+            calculate privacy spent. Default: 1.5.
         mean(float): Average value of random noise. Default: 0.0
         noise_decay_rate(float): Hyper parameter for controlling the noise decay.
             Default: 6e-4.
 
@@ -146,7 +182,7 @@ class AdaGaussianRandom(Mechanisms):
     Examples:
         >>> gradients = Tensor([0.2, 0.9], mstype.float32)
         >>> norm_bound = 1.0
-        >>> initial_noise_multiplier = 5.0
+        >>> initial_noise_multiplier = 1.5
         >>> mean = 0.0
         >>> noise_decay_rate = 6e-4
         >>> decay_policy = "Time"
 
@@ -156,7 +192,7 @@ class AdaGaussianRandom(Mechanisms):
         >>> print(res)
     """
 
-    def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0, mean=0.0,
+    def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0,
                  noise_decay_rate=6e-4, decay_policy='Time', seed=0):
         super(AdaGaussianRandom, self).__init__()
         norm_bound = check_value_positive('norm_bound', norm_bound)
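
Pulling the two mechanisms together: `initial_noise_multiplier` times `norm_bound` gives the noise stddev, which 'Gaussian' holds constant while 'AdaGaussian' decays it step by step at `noise_decay_rate`. A NumPy sketch of that contrast; the time-decay formula below is an illustrative assumption, not AdaGaussianRandom's exact 'Time' schedule:

```python
import numpy as np

def gaussian_noise(grad, norm_bound=0.5, noise_multiplier=1.5, seed=0):
    """Constant-sigma mechanism: stddev = noise_multiplier * norm_bound."""
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, noise_multiplier * norm_bound, size=grad.shape)

def ada_sigma(step, initial_multiplier=1.5, norm_bound=1.0, decay_rate=6e-4):
    """Hypothetical 'Time' decay: sigma shrinks as training proceeds.
    (The real schedule lives inside AdaGaussianRandom and may differ.)"""
    return initial_multiplier * norm_bound / (1.0 + decay_rate * step)

gradients = np.array([0.2, 0.9], dtype=np.float32)
print(gradients + gaussian_noise(gradients))               # like GaussianRandom
print([round(ada_sigma(s), 4) for s in (0, 1000, 10000)])  # decaying sigma
```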


mindarmour/diff_privacy/train/model.py (+19, -28)

@@ -72,38 +72,29 @@ class DPModel(Model):
     mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
     Examples:
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
-        >>>         self.bn = nn.BatchNorm2d(64)
-        >>>         self.relu = nn.ReLU()
-        >>>         self.flatten = nn.Flatten()
-        >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
-        >>>
-        >>>     def construct(self, x):
-        >>>         x = self.conv(x)
-        >>>         x = self.bn(x)
-        >>>         x = self.relu(x)
-        >>>         x = self.flatten(x)
-        >>>         out = self.fc(x)
-        >>>         return out
-        >>>
-        >>> net = Net()
         >>> norm_clip = 1.0
         >>> initial_noise_multiplier = 0.01
+        >>> network = LeNet5()
+        >>> batch_size = 32
+        >>> batches = 128
+        >>> epochs = 1
+        >>> micro_batches = 2
         >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-        >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
-        >>> mech = MechanismsFactory().create('Gaussian',
-        >>>                                   norm_bound=args.norm_clip,
-        >>>                                   initial_noise_multiplier=args.initial_noise_multiplier)
-        >>> model = DPModel(micro_batches=2,
-        >>>                 norm_clip=1.0,
-        >>>                 mech=mech,
-        >>>                 network=net,
+        >>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
+        >>> factory_opt.set_mechanisms('Gaussian',
+        >>>                            norm_bound=norm_clip,
+        >>>                            initial_noise_multiplier=initial_noise_multiplier)
+        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> model = DPModel(micro_batches=micro_batches,
+        >>>                 norm_clip=norm_clip,
+        >>>                 mech=None,
+        >>>                 network=network,
         >>>                 loss_fn=loss,
         >>>                 optimizer=net_opt,
         >>>                 metrics=None)
-        >>> dataset = get_dataset()
-        >>> model.train(2, dataset)
+        >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
+        >>> ms_ds.set_dataset_size(batch_size * batches)
+        >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
     """
 
     def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
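
For orientation, the parameters exercised in the new example map onto the standard DP-SGD recipe: DPModel splits each batch into micro_batches slices, clips each slice's gradient to norm_clip, and adds noise from the attached mechanism (mech, or one bound to the optimizer through DPOptimizerClassFactory). A toy NumPy sketch of that clip-then-noise aggregation; names here are illustrative, not MindArmour internals:

```python
import numpy as np

def dp_sgd_step(per_sample_grads, norm_clip=1.0, noise_multiplier=1.5,
                micro_batches=2, seed=0):
    """Toy DP-SGD aggregation: clip each micro-batch gradient to norm_clip,
    sum, add Gaussian noise with stddev noise_multiplier * norm_clip, average."""
    rng = np.random.default_rng(seed)
    clipped = []
    for g in np.array_split(per_sample_grads, micro_batches):
        g = g.mean(axis=0)                                   # one micro-batch gradient
        scale = min(1.0, norm_clip / (np.linalg.norm(g) + 1e-12))
        clipped.append(g * scale)                            # enforce l2 norm <= norm_clip
    total = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * norm_clip, size=total.shape)
    return (total + noise) / micro_batches

grads = np.random.default_rng(1).normal(size=(32, 10))       # 32 samples, 10 params
print(dp_sgd_step(grads))
```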

