|
|
@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow |
|
|
|
from mindspore.nn import Cell |
|
|
|
from mindspore import ParameterTuple |
|
|
|
|
|
|
|
from mindarmour.diff_privacy.mechanisms import mechanisms |
|
|
|
from mindarmour.utils._check_param import check_param_type |
|
|
|
from mindarmour.utils._check_param import check_value_positive |
|
|
|
from mindarmour.utils._check_param import check_int_positive |
|
|
|
|
|
|
|
|
|
|
|
GRADIENT_CLIP_TYPE = 1 |
|
|
|
grad_scale = C.MultitypeFuncGraph("grad_scale") |
|
|
@@ -67,7 +70,7 @@ class DPModel(Model): |
|
|
|
This class is overload mindspore.train.model.Model. |
|
|
|
|
|
|
|
Args: |
|
|
|
micro_batches (int): The number of small batches split from an origianl batch. Default: None. |
|
|
|
micro_batches (int): The number of small batches split from an origianl batch. Default: 2. |
|
|
|
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0. |
|
|
|
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None. |
|
|
|
|
|
|
@@ -106,14 +109,17 @@ class DPModel(Model): |
|
|
|
>>> dataset = get_dataset() |
|
|
|
>>> model.train(2, dataset) |
|
|
|
""" |
|
|
|
def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs): |
|
|
|
def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs): |
|
|
|
if micro_batches: |
|
|
|
self._micro_batches = int(micro_batches) |
|
|
|
self._micro_batches = check_int_positive('micro_batches', micro_batches) |
|
|
|
else: |
|
|
|
self._micro_batches = None |
|
|
|
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float) |
|
|
|
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip) |
|
|
|
self._dp_mech = dp_mech |
|
|
|
if isinstance(dp_mech, mechanisms.Mechanisms): |
|
|
|
self._dp_mech = dp_mech |
|
|
|
else: |
|
|
|
raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'.format(type(dp_mech))) |
|
|
|
super(DPModel, self).__init__(**kwargs) |
|
|
|
|
|
|
|
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): |
|
|
|