From a98b84717c01d9d186d03434dea0036847223f0c Mon Sep 17 00:00:00 2001
From: zhenghuanhuan
Date: Fri, 29 May 2020 17:59:32 +0800
Subject: [PATCH] [MA][diff_privacy][Func] micro_batches and dp_mech not checked

https://gitee.com/mindspore/dashboard/issues?id=I1IS9G
---
 mindarmour/diff_privacy/optimizer/optimizer.py | 16 +++-------------
 mindarmour/diff_privacy/train/model.py         | 14 ++++++++++----
 2 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/mindarmour/diff_privacy/optimizer/optimizer.py b/mindarmour/diff_privacy/optimizer/optimizer.py
index a844e79..0162a20 100644
--- a/mindarmour/diff_privacy/optimizer/optimizer.py
+++ b/mindarmour/diff_privacy/optimizer/optimizer.py
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.
 
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
 
     Returns:
         Optimizer, Optimizer class
@@ -39,7 +39,7 @@ class DPOptimizerClassFactory:
         >>>                          learning_rate=cfg.lr,
         >>>                          momentum=cfg.momentum)
     """
-    def __init__(self, micro_batches=None):
+    def __init__(self, micro_batches=2):
         self._mech_factory = MechanismsFactory()
         self.mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)
@@ -72,17 +72,7 @@ class DPOptimizerClassFactory:
         if policy == 'Adam':
             cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
             return cls
-        if policy == 'AdamWeightDecay':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'AdamWeightDecayDynamicLR':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
-                                               self.mech,
-                                               self._micro_batches,
-                                               *args, **kwargs)
-            return cls
-        raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
-                        "'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
+        raise NameError("The {} is not implemented, please choose ['SGD', 'Momentum', 'Adam']".format(policy))
 
     def _get_dp_optimizer_class(self, cls, mech, micro_batches):
         """
diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py
index 434bbe0..b423327 100644
--- a/mindarmour/diff_privacy/train/model.py
+++ b/mindarmour/diff_privacy/train/model.py
@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
 from mindspore.nn import Cell
 from mindspore import ParameterTuple
 
+from mindarmour.diff_privacy.mechanisms import mechanisms
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_int_positive
+
 
 GRADIENT_CLIP_TYPE = 1
 grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -67,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.
 
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
@@ -106,14 +109,17 @@ class DPModel(Model):
         >>> dataset = get_dataset()
         >>> model.train(2, dataset)
     """
-    def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
+    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
         if micro_batches:
-            self._micro_batches = int(micro_batches)
+            self._micro_batches = check_int_positive('micro_batches', micro_batches)
         else:
             self._micro_batches = None
         float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
         self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
-        self._dp_mech = dp_mech
+        if isinstance(dp_mech, mechanisms.Mechanisms):
+            self._dp_mech = dp_mech
+        else:
+            raise TypeError('dp mechanisms should be instance of class Mechanisms, but got {}'.format(type(dp_mech)))
         super(DPModel, self).__init__(**kwargs)
 
     def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):