Rename mech._decay_policy to mech._noise_update; after rebase, rename norm_clip to norm_bound; change MechanismsFactory to NoiseMechanismsFactory in the tests.
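In practice, call sites migrate like this (a sketch; the argument values and the 'Exp' policy are illustrative, taken from the example diffs below):

# Before (old names, illustrative):
mech = MechanismsFactory().create('AdaGaussian',
                                  norm_bound=cfg.norm_clip,
                                  initial_noise_multiplier=0.5)
# After (new names; internally, mech._decay_policy is now mech._noise_update):
mech = NoiseMechanismsFactory().create('AdaGaussian',
                                       norm_bound=cfg.norm_bound,
                                       initial_noise_multiplier=0.5,
                                       noise_update='Exp')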
@@ -32,7 +32,7 @@ mnist_cfg = edict({
'data_path': './MNIST_unzip', # the path of training and testing data set
'dataset_sink_mode': False, # whether deliver all training data to device one time
'micro_batches': 16, # the number of small batches split from an original batch
'norm_clip': 1.0, # the clip bound of the gradients of model's training parameters
'norm_bound': 1.0, # the clip bound of the gradients of model's training parameters
'initial_noise_multiplier': 0.5, # the initial multiplication coefficient of the noise added to training
# parameters' gradients
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
@@ -115,8 +115,9 @@ if __name__ == "__main__":
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
norm_bound=cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier)
norm_bound=cfg.norm_bound,
initial_noise_multiplier=cfg.initial_noise_multiplier,
noise_update='Exp')
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
@@ -136,11 +137,11 @@ if __name__ == "__main__":
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=cfg.batch_size,
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_bound,
per_print_times=10)
# Create the DP model for training.
model = DPModel(micro_batches=cfg.micro_batches,
norm_clip=cfg.norm_clip,
norm_bound=cfg.norm_bound,
noise_mech=noise_mech,
clip_mech=clip_mech,
network=network,
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_pynative_mode.py --data_path /YourDataPath --micro_batches=2
python lenet5_dp_pynative_model.py --data_path /YourDataPath --micro_batches=2
"""
import os
@@ -32,6 +32,7 @@ import mindspore.common.dtype as mstype
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg
@@ -108,21 +109,35 @@ if __name__ == "__main__":
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
dp_opt = DPOptimizerClassFactory(micro_batches=cfg.micro_batches)
dp_opt.set_mechanisms(cfg.mechanisms,
norm_bound=cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier)
dp_opt.set_mechanisms(cfg.noise_mechanisms,
norm_bound=cfg.norm_bound,
initial_noise_multiplier=cfg.initial_noise_multiplier,
noise_update='Exp')
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
# target_unclipped_quantile is the target quantile of norm clip,
# fraction_stddev is the stddev of Gaussian normal which used in
# empirical_fraction, the formula is
# $empirical_fraction + N(0, fraction_stddev)$.
clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
decay_policy=cfg.clip_decay_policy,
learning_rate=cfg.clip_learning_rate,
target_unclipped_quantile=cfg.target_unclipped_quantile,
fraction_stddev=cfg.fraction_stddev)
net_opt = dp_opt.create('Momentum')(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=cfg.batch_size,
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.norm_bound,
per_print_times=10)
# Create the DP model for training.
model = DPModel(micro_batches=cfg.micro_batches,
norm_clip=cfg.norm_clip,
mech=None,
norm_bound=cfg.norm_bound,
noise_mech=None,
clip_mech=clip_mech,
network=network,
loss_fn=net_loss,
optimizer=net_opt,
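Note that the monitor above receives initial_noise_multiplier*cfg.norm_bound because the Gaussian stddev actually applied to the gradients is the product of the two (see NoiseGaussianRandom.construct later in this diff). A quick check of that relationship, with illustrative values:

# stddev of the added noise, assuming the construct() shown later in this diff:
# stddev = norm_bound * initial_noise_multiplier
norm_bound = 1.0
initial_noise_multiplier = 0.5
stddev = norm_bound * initial_noise_multiplier  # 0.5, the value the RDP monitor accounts for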
@@ -2,7 +2,7 @@
This module provide Differential Privacy feature to protect user privacy.
"""
from .mechanisms.mechanisms import NoiseGaussianRandom
from .mechanisms.mechanisms import AdaGaussianRandom
from .mechanisms.mechanisms import NoiseAdaGaussianRandom
from .mechanisms.mechanisms import AdaClippingWithGaussianRandom
from .mechanisms.mechanisms import NoiseMechanismsFactory
from .mechanisms.mechanisms import ClipMechanismsFactory
@@ -11,7 +11,7 @@ from .optimizer.optimizer import DPOptimizerClassFactory
from .train.model import DPModel
__all__ = ['NoiseGaussianRandom',
'AdaGaussianRandom',
'NoiseAdaGaussianRandom',
'AdaClippingWithGaussianRandom',
'NoiseMechanismsFactory',
'ClipMechanismsFactory',
@@ -19,6 +19,7 @@ from abc import abstractmethod
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.composite import normal
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
@@ -55,7 +56,7 @@ class ClipMechanismsFactory:
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
@@ -65,7 +66,7 @@ class ClipMechanismsFactory:
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
>>> next_norm_bound = ada_clip(beta, norm_bound)
"""
if mech_name == 'Gaussian':
@@ -81,25 +82,32 @@ class NoiseMechanismsFactory:
pass
@staticmethod
def create(policy, *args, **kwargs):
def create(mech_name='Gaussian', norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, noise_decay_rate=6e-6,
noise_update=None):
"""
Args:
policy(str): Noise generated strategy, could be 'Gaussian' or
mech_name(str): Noise generated strategy, could be 'Gaussian' or
'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism
while be constant with 'Gaussian' mechanism.
args(Union[float, str]): Parameters used for creating noise
mechanisms.
kwargs(Union[float, str]): Parameters used for creating noise
mechanisms.
norm_bound(float): Clipping bound for the l2 norm of the gradients.
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent.
seed(int): Original random seed. If seed=0, random normal will use a secure
random number; if seed!=0, random normal will generate values using the
given seed.
noise_decay_rate(float): Hyperparameter for controlling the noise decay.
noise_update(str): Mechanisms parameters update policy. Default: None,
meaning no parameters need updating.
Raises:
NameError: `policy` must be in ['Gaussian', 'AdaGaussian'].
NameError: `mech_name` must be in ['Gaussian', 'AdaGaussian'].
Returns:
Mechanisms, class of noise generated Mechanism.
Examples:
>>> norm_clip = 1.0
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
@@ -107,7 +115,7 @@ class NoiseMechanismsFactory:
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip,
>>> norm_bound=norm_bound,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> decay_policy='Linear',
@@ -118,7 +126,7 @@ class NoiseMechanismsFactory:
>>> momentum=0.9)
>>> model = DPModel(micro_batches=2,
>>> clip_mech=clip_mech,
>>> norm_clip=norm_clip,
>>> norm_bound=norm_bound,
>>> noise_mech=noise_mech,
>>> network=network,
>>> loss_fn=loss,
@@ -129,15 +137,22 @@ class NoiseMechanismsFactory:
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
if policy == 'Gaussian':
return NoiseGaussianRandom(*args, **kwargs)
if policy == 'AdaGaussian':
return AdaGaussianRandom(*args, **kwargs)
if mech_name == 'Gaussian':
return NoiseGaussianRandom(norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier,
seed=seed,
noise_update=noise_update)
if mech_name == 'AdaGaussian':
return NoiseAdaGaussianRandom(norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier,
seed=seed,
noise_decay_rate=noise_decay_rate,
noise_update=noise_update)
raise NameError("The {} is not implement, please choose " | |||
"['Gaussian', 'AdaGaussian']".format(policy)) | |||
"['Gaussian', 'AdaGaussian']".format(mech_name)) | |||
class Mechanisms(Cell):
class _Mechanisms(Cell):
"""
Basic class of noise generated mechanism.
"""
@@ -149,21 +164,19 @@ class Mechanisms(Cell):
"""
class NoiseGaussianRandom(Mechanisms):
class NoiseGaussianRandom(_Mechanisms):
"""
Gaussian noise generated mechanism.
Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients.
Default: 0.5.
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5.
calculate privacy spent.
seed(int): Original random seed. If seed=0, random normal will use a secure
random number; if seed!=0, random normal will generate values using the
given seed. Default: 0.
policy(str): Mechanisms parameters update policy. Default: None, no
parameters need update.
given seed.
noise_update(str): Mechanisms parameters update policy. Default: None.
Returns:
Tensor, generated noise with shape like given gradients.
@@ -172,24 +185,25 @@ class NoiseGaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
>>> seed = 0
>>> noise_update = None
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_update)
>>> res = net(gradients)
>>> print(res)
"""
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0,
policy=None):
def __init__(self, norm_bound, initial_noise_multiplier, seed, noise_update=None):
super(NoiseGaussianRandom, self).__init__()
self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32)
self._initial_noise_multiplier = check_value_positive(
'initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier,
mstype.float32)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
self._mean = Tensor(0, mstype.float32)
self._normal = P.Normal(seed=seed)
self._decay_policy = policy
if noise_update is not None:
raise ValueError('noise_update must be None in GaussianRandom class, but got {}.'.format(noise_update))
self._noise_update = noise_update
self._seed = seed
def construct(self, gradients):
"""
@@ -203,26 +217,25 @@ class NoiseGaussianRandom(Mechanisms):
"""
shape = P.Shape()(gradients)
stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
noise = self._normal(shape, self._mean, stddev)
noise = normal(shape, self._mean, stddev, self._seed)
return noise
class AdaGaussianRandom(Mechanisms):
class NoiseAdaGaussianRandom(NoiseGaussianRandom):
"""
Adaptive Gaussian noise generated mechanism. Noise would be decayed with
training. Decay mode could be 'Time' mode or 'Step' mode.
training. Decay mode could be 'Time' mode, 'Step' mode, 'Exp' mode.
Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients.
Default: 1.0.
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5.
calculate privacy spent.
seed(int): Original random seed. If seed=0, random normal will use a secure
random number; if seed!=0, random normal will generate values using the
given seed.
noise_decay_rate(float): Hyperparameter for controlling the noise decay.
Default: 6e-4.
decay_policy(str): Noise decay strategies include 'Step' and 'Time'.
Default: 'Time'.
seed(int): Original random seed. Default: 0.
noise_update(str): Noise decay strategies include 'Step', 'Time' and 'Exp'.
Returns:
Tensor, generated noise with shape like given gradients.
@@ -231,56 +244,27 @@ class AdaGaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> noise_decay_rate = 6e-4
>>> decay_policy = "Time"
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
>>> noise_decay_rate, decay_policy)
>>> noise_update = "Time"
>>> net = NoiseAdaGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_decay_rate, noise_update)
>>> res = net(gradients)
>>> print(res)
"""
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5,
noise_decay_rate=6e-4, decay_policy='Time', seed=0):
super(AdaGaussianRandom, self).__init__()
norm_bound = check_value_positive('norm_bound', norm_bound)
initial_noise_multiplier = check_value_positive(
'initial_noise_multiplier',
initial_noise_multiplier)
self._norm_bound = Tensor(norm_bound, mstype.float32)
initial_noise_multiplier = Tensor(initial_noise_multiplier,
mstype.float32)
self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
name='initial_noise_multiplier')
self._noise_multiplier = Parameter(initial_noise_multiplier,
def __init__(self, norm_bound, initial_noise_multiplier, seed, noise_decay_rate, noise_update):
super(NoiseAdaGaussianRandom, self).__init__(norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier,
seed=seed)
self._noise_multiplier = Parameter(self._initial_noise_multiplier,
name='noise_multiplier')
self._mean = Tensor(0, mstype.float32)
noise_decay_rate = check_param_type('noise_decay_rate',
noise_decay_rate, float)
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
if decay_policy not in ['Time', 'Step', 'Exp']:
raise NameError("The decay_policy must be in ['Time', 'Step', 'Exp'], but "
"get {}".format(decay_policy))
self._decay_policy = decay_policy
self._mul = P.Mul()
self._normal = P.Normal(seed=seed)
def construct(self, gradients):
"""
Generate adaptive Gaussian noise.
Args:
gradients(Tensor): The gradients.
Returns:
Tensor, generated noise with shape like given gradients.
"""
shape = P.Shape()(gradients)
noise = self._normal(shape, self._mean,
self._mul(self._noise_multiplier,
self._norm_bound))
return noise
if noise_update not in ['Time', 'Step', 'Exp']:
raise NameError("The noise_update must be in ['Time', 'Step', 'Exp'], but "
"got {}".format(noise_update))
self._noise_update = noise_update
class _MechanismsParamsUpdater(Cell):
@@ -288,7 +272,7 @@ class _MechanismsParamsUpdater(Cell):
Update mechanisms parameters, the parameters will refresh in train period.
Args:
policy(str): Pass in by the mechanisms class, mechanisms parameters
noise_update(str): Pass in by the mechanisms class, mechanisms parameters
update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
controlling the decay size.
@@ -300,9 +284,9 @@ class _MechanismsParamsUpdater(Cell):
Returns:
Tuple, next params value.
"""
def __init__(self, policy, decay_rate, cur_noise_multiplier, init_noise_multiplier):
def __init__(self, noise_update, decay_rate, cur_noise_multiplier, init_noise_multiplier):
super(_MechanismsParamsUpdater, self).__init__()
self._policy = policy
self._noise_update = noise_update
self._decay_rate = decay_rate
self._cur_noise_multiplier = cur_noise_multiplier
self._init_noise_multiplier = init_noise_multiplier
@@ -322,27 +306,27 @@ class _MechanismsParamsUpdater(Cell):
Returns:
Tuple, next step parameters value.
"""
if self._policy == 'Time':
if self._noise_update == 'Time':
temp = self._div(self._init_noise_multiplier, self._cur_noise_multiplier)
temp = self._add(temp, self._decay_rate)
next_noise_multiplier = self._assign(self._cur_noise_multiplier,
self._div(self._init_noise_multiplier, temp))
elif self._policy == 'Step':
elif self._noise_update == 'Step':
temp = self._sub(self._one, self._decay_rate)
next_noise_multiplier = self._assign(self._cur_noise_multiplier,
self._mul(temp, self._cur_noise_multiplier))
else:
next_noise_multiplier = self._assign(self._cur_noise_multiplier,
self._div(self._one, self._exp(self._one)))
self._div(self._cur_noise_multiplier, self._exp(self._decay_rate)))
return next_noise_multiplier
class AdaClippingWithGaussianRandom(Cell):
"""
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
$ norm_clip = norm_clip - learning_rate*(beta-target_unclipped_quantile)$.
$ norm_bound = norm_bound - learning_rate*(beta-target_unclipped_quantile)$.
`decay_policy` is 'Geometric', the update formula is
$ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
$ norm_bound = norm_bound*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
where beta is the empirical fraction of samples with the value at most
`target_unclipped_quantile`.
@@ -363,7 +347,7 @@ class AdaClippingWithGaussianRandom(Cell):
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> norm_bound = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
@@ -371,7 +355,7 @@ class AdaClippingWithGaussianRandom(Cell):
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
>>> next_norm_bound = ada_clip(beta, norm_bound)
"""
@@ -400,32 +384,32 @@ class AdaClippingWithGaussianRandom(Cell):
self._sub = P.Sub()
self._mul = P.Mul()
self._exp = P.Exp()
self._normal = P.Normal(seed=seed)
self._seed = seed
def construct(self, empirical_fraction, norm_clip):
def construct(self, empirical_fraction, norm_bound):
"""
Update value of norm_clip.
Update value of norm_bound.
Args:
empirical_fraction(Tensor): empirical fraction of samples with the
value at most `target_unclipped_quantile`.
norm_clip(Tensor): Clipping bound for the l2 norm of the gradients.
norm_bound(Tensor): Clipping bound for the l2 norm of the gradients.
Returns:
Tensor, generated noise with shape like given gradients.
"""
fraction_noise = self._normal((1,), self._zero, self._fraction_stddev)
fraction_noise = normal((1,), self._zero, self._fraction_stddev, self._seed)
empirical_fraction = self._add(empirical_fraction, fraction_noise)
if self._decay_policy == 'Linear':
grad_clip = self._sub(empirical_fraction,
self._target_unclipped_quantile)
next_norm_clip = self._sub(norm_clip,
self._mul(self._learning_rate, grad_clip))
next_norm_bound = self._sub(norm_bound,
self._mul(self._learning_rate, grad_clip))
# decay_policy == 'Geometric'
else:
grad_clip = self._sub(empirical_fraction,
self._target_unclipped_quantile)
grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip))
next_norm_clip = self._mul(norm_clip, grad_clip)
return next_norm_clip
next_norm_bound = self._mul(norm_bound, grad_clip)
return next_norm_bound
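For reference, the three noise_update rules implemented by _MechanismsParamsUpdater and the two decay_policy rules of AdaClippingWithGaussianRandom reduce to the following pure-Python sketch (function names here are illustrative, not part of the API):

import math

def next_noise_multiplier(cur, init, decay_rate, noise_update):
    # Mirrors _MechanismsParamsUpdater.construct above.
    if noise_update == 'Time':
        return init / (init / cur + decay_rate)
    if noise_update == 'Step':
        return (1.0 - decay_rate) * cur
    if noise_update == 'Exp':  # new in this change: sigma_t = sigma_{t-1} / e^decay_rate
        return cur / math.exp(decay_rate)
    raise NameError("noise_update must be in ['Time', 'Step', 'Exp']")

def next_norm_bound(beta, norm_bound, learning_rate, target_unclipped_quantile, decay_policy):
    # Mirrors AdaClippingWithGaussianRandom.construct above (fraction noise omitted).
    if decay_policy == 'Linear':
        return norm_bound - learning_rate * (beta - target_unclipped_quantile)
    # decay_policy == 'Geometric'
    return norm_bound * math.exp(-learning_rate * (beta - target_unclipped_quantile))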
@@ -127,8 +127,8 @@ class DPOptimizerClassFactory:
self._micro_float = Tensor(micro_batches, mstype.float32)
self._mech_param_updater = None
if self._mech is not None and self._mech._decay_policy is not None:
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
if self._mech is not None and self._mech._noise_update is not None:
self._mech_param_updater = _MechanismsParamsUpdater(noise_update=self._mech._noise_update,
decay_rate=self._mech._noise_decay_rate,
cur_noise_multiplier=
self._mech._noise_multiplier,
@@ -75,7 +75,7 @@ class DPModel(Model):
Args:
micro_batches (int): The number of small batches split from an original
batch. Default: 2.
norm_clip (float): Use to clip the bound, if set 1, will retun the
norm_bound (float): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
@@ -83,7 +83,7 @@ class DPModel(Model):
Default: None.
Examples:
>>> norm_clip = 1.0
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
@@ -93,7 +93,7 @@ class DPModel(Model):
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt.set_mechanisms('Gaussian',
>>> norm_bound=norm_clip,
>>> norm_bound=norm_bound,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
>>> learning_rate=0.1, momentum=0.9)
@@ -103,7 +103,7 @@ class DPModel(Model):
>>> target_unclipped_quantile=0.9,
>>> fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
>>> norm_clip=norm_clip,
>>> norm_bound=norm_bound,
>>> clip_mech=clip_mech,
>>> noise_mech=None,
>>> network=network,
@@ -116,17 +116,18 @@ class DPModel(Model):
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None,
def __init__(self, micro_batches=2, norm_bound=1.0, noise_mech=None,
clip_mech=None, **kwargs):
if micro_batches:
self._micro_batches = check_int_positive('micro_batches',
micro_batches)
else:
self._micro_batches = None
norm_clip = check_param_type('norm_clip', norm_clip, float)
norm_clip = check_value_positive('norm_clip', norm_clip)
norm_clip = Tensor(norm_clip, mstype.float32)
self._norm_clip = Parameter(norm_clip, 'norm_clip')
norm_bound = check_param_type('norm_bound', norm_bound, float)
norm_bound = check_value_positive('norm_bound', norm_bound)
norm_bound = Tensor(norm_bound, mstype.float32)
self._norm_bound = Parameter(norm_bound, 'norm_bound')
if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
msg = 'DPOptimizer is not supported while noise_mech is not None'
LOGGER.error(TAG, msg)
@@ -219,14 +220,14 @@ class DPModel(Model):
optimizer,
scale_update_cell=update_cell,
micro_batches=self._micro_batches,
norm_clip=self._norm_clip,
norm_bound=self._norm_bound,
clip_mech=self._clip_mech,
noise_mech=self._noise_mech).set_train()
return network
network = _TrainOneStepCell(network,
optimizer,
self._norm_clip,
self._norm_bound,
loss_scale,
micro_batches=self._micro_batches,
clip_mech=self._clip_mech,
@@ -347,7 +348,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
Default: None.
micro_batches (int): The number of small batches split from an original
batch. Default: None.
norm_clip (Tensor): Use to clip the bound, if set 1, will return the
norm_bound (Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
@@ -366,7 +367,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
"""
def __init__(self, network, optimizer, scale_update_cell=None,
micro_batches=None, norm_clip=1.0, noise_mech=None,
micro_batches=None, norm_bound=1.0, noise_mech=None,
clip_mech=None):
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
@@ -405,15 +406,13 @@ class _TrainOneStepWithLossScaleCell(Cell):
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(
Tensor(scale_update_cell.get_loss_scale(),
dtype=mstype.float32),
name="loss_scale")
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
# dp params
self._micro_batches = micro_batches
self._norm_clip = norm_clip
self._norm_bound = norm_bound
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._noise_mech = noise_mech
@@ -433,9 +432,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
self._cast = P.Cast()
self._noise_mech_param_updater = None
if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
if self._noise_mech is not None and self._noise_mech._noise_update is not None:
self._noise_mech_param_updater = _MechanismsParamsUpdater(
policy=self._noise_mech._decay_policy,
noise_update=self._noise_mech._noise_update,
decay_rate=self._noise_mech._noise_decay_rate,
cur_noise_multiplier=
self._noise_mech._noise_multiplier,
@@ -477,10 +476,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
self._cast(self._less(norm_grad, self._norm_bound),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
self._norm_clip)
self._norm_bound)
grads = record_grad
total_loss = loss
for i in range(1, self._micro_batches):
@@ -497,12 +496,12 @@ class _TrainOneStepWithLossScaleCell(Cell):
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
self._cast(self._less(norm_grad, self._norm_bound),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad,
GRADIENT_CLIP_TYPE,
self._norm_clip)
self._norm_bound)
grads = self._tuple_add(grads, record_grad)
total_loss = P.TensorAdd()(total_loss, loss)
loss = P.Div()(total_loss, self._micro_float)
@@ -552,8 +551,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
ret = (loss, cond, scaling_sens)
if self._clip_mech is not None:
next_norm_clip = self._clip_mech(beta, self._norm_clip)
P.assign(self._norm_clip, next_norm_clip)
next_norm_bound = self._clip_mech(beta, self._norm_bound)
P.assign(self._norm_bound, next_norm_bound)
return F.depend(ret, opt)
@@ -573,7 +572,7 @@ class _TrainOneStepCell(Cell):
propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an original
batch. Default: None.
norm_clip (Tensor): Use to clip the bound, if set 1, will return the
norm_bound (Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type
of noise. Default: None.
@@ -586,7 +585,7 @@ class _TrainOneStepCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, norm_clip=1.0, sens=1.0,
def __init__(self, network, optimizer, norm_bound=1.0, sens=1.0,
micro_batches=None,
noise_mech=None, clip_mech=None):
super(_TrainOneStepCell, self).__init__(auto_prefix=False)
@@ -616,7 +615,7 @@ class _TrainOneStepCell(Cell):
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._micro_batches = micro_batches
self._norm_clip = norm_clip
self._norm_bound = norm_bound
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._noise_mech = noise_mech
@@ -637,9 +636,9 @@ class _TrainOneStepCell(Cell):
self._micro_float = Tensor(micro_batches, mstype.float32)
self._noise_mech_param_updater = None
if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
if self._noise_mech is not None and self._noise_mech._noise_update is not None:
self._noise_mech_param_updater = _MechanismsParamsUpdater(
policy=self._noise_mech._decay_policy,
noise_update=self._noise_mech._noise_update,
decay_rate=self._noise_mech._noise_decay_rate,
cur_noise_multiplier=
self._noise_mech._noise_multiplier,
@@ -664,11 +663,11 @@ class _TrainOneStepCell(Cell):
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
self._cast(self._less(norm_grad, self._norm_bound),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
self._norm_clip)
self._norm_bound)
grads = record_grad
total_loss = loss
for i in range(1, self._micro_batches):
@@ -683,12 +682,12 @@ class _TrainOneStepCell(Cell):
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
self._cast(self._less(norm_grad, self._norm_bound),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad,
GRADIENT_CLIP_TYPE,
self._norm_clip)
self._norm_bound)
grads = self._tuple_add(grads, record_grad)
total_loss = P.TensorAdd()(total_loss, loss)
loss = self._div(total_loss, self._micro_float)
@@ -712,8 +711,8 @@ class _TrainOneStepCell(Cell):
grads = self.grad_reducer(grads)
if self._clip_mech is not None:
next_norm_clip = self._clip_mech(beta, self._norm_clip)
self._norm_clip = self._assign(self._norm_clip, next_norm_clip)
loss = F.depend(loss, next_norm_clip)
next_norm_bound = self._clip_mech(beta, self._norm_bound)
self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
loss = F.depend(loss, next_norm_bound)
return F.depend(loss, self.optimizer(grads))
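The renamed norm_bound threads through the per-micro-batch training step above: each micro-batch's gradients are clipped by global l2 norm to norm_bound, beta counts how many micro-batches fall under the bound, noise is added, and clip_mech then moves norm_bound for the next step. A NumPy sketch of that flow, under stated assumptions (grad_fn and the exact noise placement are placeholders, not the DPModel internals):

import numpy as np

def dp_train_step(grad_fn, micro_batch_records, norm_bound, noise_fn=None, clip_fn=None):
    grads, beta = None, 0.0
    for record in micro_batch_records:
        g = grad_fn(record)  # list of per-parameter gradient arrays
        norm = np.sqrt(sum(np.sum(x ** 2) for x in g))
        beta += float(norm < norm_bound)  # count of unclipped micro-batches
        scale = min(1.0, norm_bound / (norm + 1e-12))  # clip by global l2 norm
        g = [x * scale for x in g]
        grads = g if grads is None else [a + b for a, b in zip(grads, g)]
    if noise_fn is not None:  # e.g. N(0, norm_bound * noise_multiplier) per parameter
        grads = [x + noise_fn(x.shape) for x in grads]
    grads = [x / len(micro_batch_records) for x in grads]
    if clip_fn is not None:  # adaptive bound update, cf. AdaClippingWithGaussianRandom
        norm_bound = clip_fn(beta, norm_bound)
    return grads, norm_bound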
@@ -63,14 +63,14 @@ class ModelCoverageMetrics:
self._model = check_model('model', model, Model)
self._segmented_num = check_int_positive('segmented_num', segmented_num)
self._neuron_num = check_int_positive('neuron_num', neuron_num)
if self._neuron_num > 1e+10:
if self._neuron_num >= 1e+10:
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError ' \
'would occur'
LOGGER.error(TAG, msg)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._lower_bounds = [np.inf]*neuron_num
self._upper_bounds = [-np.inf]*neuron_num
self._var = [0]*neuron_num
self._lower_bounds = [np.inf]*self._neuron_num
self._upper_bounds = [-np.inf]*self._neuron_num
self._var = [0]*self._neuron_num
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in
range(self._neuron_num)]
self._lower_corner_hits = [0]*self._neuron_num
@@ -1,6 +1,6 @@
numpy >= 1.17.0
scipy >= 1.3.3
matplotlib >= 3.1.3
matplotlib >= 3.2.1
Pillow >= 2.0.0
pytest >= 4.3.1
wheel >= 0.32.0
@@ -104,7 +104,7 @@ setup(
install_requires=[
'scipy >= 1.3.3',
'numpy >= 1.17.0',
'matplotlib >= 3.1.3',
'matplotlib >= 3.2.1',
'Pillow >= 2.0.0'
],
classifiers=[
@@ -19,8 +19,7 @@ import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindarmour.diff_privacy import NoiseGaussianRandom
from mindarmour.diff_privacy import AdaGaussianRandom
from mindarmour.diff_privacy import NoiseAdaGaussianRandom
from mindarmour.diff_privacy import AdaClippingWithGaussianRandom
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
@@ -30,72 +29,98 @@ from mindarmour.diff_privacy import ClipMechanismsFactory
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_gaussian():
def test_graph_factory():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
res = net(grad)
print(res)
alpha = 0.5
noise_update = 'Step'
factory = NoiseMechanismsFactory()
noise_mech = factory.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_mech(grad)
print('Gaussian noise: ', noise)
ada_noise_mech = factory.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update=noise_update)
ada_noise = ada_noise_mech(grad)
print('ada noise: ', ada_noise)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_pynative_gaussian():
def test_pynative_factory():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
res = net(grad)
print(res)
alpha = 0.5
noise_update = 'Step'
factory = NoiseMechanismsFactory()
noise_mech = factory.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_mech(grad)
print('Gaussian noise: ', noise)
ada_noise_mech = factory.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update=noise_update)
ada_noise = ada_noise_mech(grad)
print('ada noise: ', ada_noise)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_ada_gaussian():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def test_pynative_gaussian():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
noise_decay_rate=alpha, decay_policy=decay_policy)
res = net(grad)
print(res)
noise_update = 'Step'
factory = NoiseMechanismsFactory()
noise_mech = factory.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_mech(grad)
print('Gaussian noise: ', noise)
ada_noise_mech = factory.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update=noise_update)
ada_noise = ada_noise_mech(grad)
print('ada noise: ', ada_noise)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_factory():
def test_graph_ada_gaussian():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
noise_mechanism = NoiseMechanismsFactory()
noise_construct = noise_mechanism.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_construct(grad)
print('Gaussian noise: ', noise)
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
decay_policy=decay_policy)
ada_noise = ada_noise_construct(grad)
print('ada noise: ', ada_noise)
noise_decay_rate = 0.5
noise_update = 'Step'
ada_noise_mech = NoiseAdaGaussianRandom(norm_bound,
initial_noise_multiplier,
seed=0,
noise_decay_rate=noise_decay_rate,
noise_update=noise_update)
res = ada_noise_mech(grad)
print(res)
@pytest.mark.level0
@@ -107,11 +132,14 @@ def test_pynative_ada_gaussian():
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
noise_decay_rate=alpha, decay_policy=decay_policy)
res = net(grad)
noise_decay_rate = 0.5
noise_update = 'Step'
ada_noise_mech = NoiseAdaGaussianRandom(norm_bound,
initial_noise_multiplier,
seed=0,
noise_decay_rate=noise_decay_rate,
noise_update=noise_update)
res = ada_noise_mech(grad)
print(res)
@@ -119,26 +147,20 @@ def test_pynative_ada_gaussian():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_pynative_factory():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
def test_graph_exponential():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
noise_mechanism = NoiseMechanismsFactory()
noise_construct = noise_mechanism.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_construct(grad)
print('Gaussian noise: ', noise)
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
decay_policy=decay_policy)
ada_noise = ada_noise_construct(grad)
noise_update = 'Exp'
factory = NoiseMechanismsFactory()
ada_noise = factory.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update=noise_update)
ada_noise = ada_noise(grad)
print('ada noise: ', ada_noise)
@@ -152,35 +174,14 @@ def test_pynative_exponential():
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Exp'
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
decay_policy=decay_policy)
ada_noise = ada_noise_construct(grad)
print('ada noise: ', ada_noise)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_exponential():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Exp'
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
decay_policy=decay_policy)
ada_noise = ada_noise_construct(grad)
noise_update = 'Exp'
factory = NoiseMechanismsFactory()
ada_noise = factory.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update=noise_update)
ada_noise = ada_noise(grad)
print('ada noise: ', ada_noise)
@@ -192,7 +193,7 @@ def test_ada_clip_gaussian_random_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
norm_bound = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
@@ -201,8 +202,8 @@ def test_ada_clip_gaussian_random_pynative():
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Liner next norm clip:', next_norm_clip)
next_norm_bound = ada_clip(beta, norm_bound)
print('Linear next norm bound:', next_norm_bound)
decay_policy = 'Geometric'
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
@@ -210,8 +211,8 @@ def test_ada_clip_gaussian_random_pynative():
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Geometric next norm clip:', next_norm_clip)
next_norm_bound = ada_clip(beta, norm_bound)
print('Geometric next norm bound:', next_norm_bound)
@pytest.mark.level0
@@ -222,7 +223,7 @@ def test_ada_clip_gaussian_random_graph():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
norm_bound = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
@@ -231,8 +232,8 @@ def test_ada_clip_gaussian_random_graph():
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Liner next norm clip:', next_norm_clip)
next_norm_bound = ada_clip(beta, norm_bound)
print('Linear next norm bound:', next_norm_bound)
decay_policy = 'Geometric'
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
@@ -240,8 +241,8 @@ def test_ada_clip_gaussian_random_graph():
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Geometric next norm clip:', next_norm_clip)
next_norm_bound = ada_clip(beta, norm_bound)
print('Geometric next norm bound:', next_norm_bound)
@pytest.mark.level0
@@ -252,18 +253,18 @@ def test_pynative_clip_mech_factory():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
norm_bound = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
clip_mechanism = ClipMechanismsFactory()
ada_clip = clip_mechanism.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_clip = ada_clip(beta, norm_clip)
print('next_norm_clip: ', next_norm_clip)
factory = ClipMechanismsFactory()
ada_clip = factory.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_bound = ada_clip(beta, norm_bound)
print('next_norm_bound: ', next_norm_bound)
@pytest.mark.level0
@@ -274,15 +275,15 @@ def test_graph_clip_mech_factory():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
norm_bound = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
clip_mechanism = ClipMechanismsFactory()
ada_clip = clip_mechanism.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_clip = ada_clip(beta, norm_clip)
print('next_norm_clip: ', next_norm_clip)
factory = ClipMechanismsFactory()
ada_clip = factory.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_bound = ada_clip(beta, norm_bound)
print('next_norm_bound: ', next_norm_bound)
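As a sanity check on the clip-factory tests above (beta=0.5, norm_bound=1.0, learning_rate=0.1, target_unclipped_quantile=0.3, and ignoring the added fraction noise), the bound should move roughly as follows:

# Linear:    next = 1.0 - 0.1 * (0.5 - 0.3)       = 0.98
# Geometric: next = 1.0 * exp(-0.1 * (0.5 - 0.3)) ≈ 0.9802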
@@ -46,7 +46,7 @@ def dataset_generator(batch_size, batches):
@pytest.mark.component_mindarmour
def test_dp_model_with_pynative_mode():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
norm_clip = 1.0
norm_bound = 1.0
initial_noise_multiplier = 0.01
network = LeNet5()
batch_size = 32
@@ -56,7 +56,7 @@ def test_dp_model_with_pynative_mode():
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
factory_opt.set_mechanisms('Gaussian',
norm_bound=norm_clip,
norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier)
net_opt = factory_opt.create('Momentum')(network.trainable_params(),
learning_rate=0.1, momentum=0.9)
@@ -66,7 +66,7 @@ def test_dp_model_with_pynative_mode():
target_unclipped_quantile=0.9,
fraction_stddev=0.01)
model = DPModel(micro_batches=micro_batches,
norm_clip=norm_clip,
norm_bound=norm_bound,
clip_mech=clip_mech,
noise_mech=None,
network=network,
@@ -86,7 +86,7 @@ def test_dp_model_with_pynative_mode():
@pytest.mark.component_mindarmour
def test_dp_model_with_graph_mode():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
norm_clip = 1.0
norm_bound = 1.0
initial_noise_multiplier = 0.01
network = LeNet5()
batch_size = 32
@@ -94,7 +94,7 @@ def test_dp_model_with_graph_mode():
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
noise_mech = NoiseMechanismsFactory().create('Gaussian',
norm_bound=norm_clip,
norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier)
clip_mech = ClipMechanismsFactory().create('Gaussian',
decay_policy='Linear',
@@ -105,7 +105,7 @@ def test_dp_model_with_graph_mode():
momentum=0.9)
model = DPModel(micro_batches=2,
clip_mech=clip_mech,
norm_clip=norm_clip,
norm_bound=norm_bound,
noise_mech=noise_mech,
network=network,
loss_fn=loss,
@@ -124,22 +124,25 @@ def test_dp_model_with_graph_mode():
@pytest.mark.component_mindarmour
def test_dp_model_with_graph_mode_ada_gaussian():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
norm_clip = 1.0
norm_bound = 1.0
initial_noise_multiplier = 0.01
network = LeNet5()
batch_size = 32
batches = 128
epochs = 1
alpha = 0.8
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
noise_mech = NoiseMechanismsFactory().create('AdaGaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
norm_bound=norm_bound,
initial_noise_multiplier=initial_noise_multiplier,
noise_decay_rate=alpha,
noise_update='Exp')
clip_mech = None
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
momentum=0.9)
model = DPModel(micro_batches=2,
clip_mech=clip_mech,
norm_clip=norm_clip,
norm_bound=norm_bound,
noise_mech=noise_mech,
network=network,
loss_fn=loss,
@@ -34,10 +34,10 @@ def test_optimizer():
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
factory = DPOptimizerClassFactory(micro_batches)
factory.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = factory.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@@ -52,10 +52,10 @@ def test_optimizer_gpu():
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
factory = DPOptimizerClassFactory(micro_batches)
factory.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = factory.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@@ -70,8 +70,8 @@ def test_optimizer_cpu():
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
factory = DPOptimizerClassFactory(micro_batches)
factory.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = factory.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)