|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Differential privacy model.
- """
- from easydict import EasyDict as edict
-
- import mindspore as ms
- from mindspore.train.model import Model
- from mindspore._checkparam import Validator as validator
- from mindspore._checkparam import Rel
- from mindspore.train import amp
- from mindspore.train.amp import _config_level
- from mindspore.common import dtype as mstype
- from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
- from mindspore.parallel._utils import _get_parallel_mode
- from mindspore.train.model import ParallelMode
- from mindspore.train.amp import _do_keep_batchnorm_fp32
- from mindspore.train.amp import _add_loss_network
- from mindspore import context
- from mindspore import nn
- from mindspore import Tensor
- from mindspore.ops import composite as C
- from mindspore.ops import operations as P
- from mindspore.ops import functional as F
- from mindspore.ops.operations import NPUGetFloatStatus
- from mindspore.ops.operations import NPUAllocFloatStatus
- from mindspore.ops.operations import NPUClearFloatStatus
- from mindspore.ops.operations import ReduceSum
- from mindspore.ops.operations import LessEqual
- from mindspore.ops.operations import ControlDepend
- from mindspore.parallel._utils import _get_mirror_mean
- from mindspore.parallel._utils import _get_device_num
- from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
- from mindspore.common.parameter import Parameter
- from mindspore.nn.wrap.loss_scale import _grad_overflow
- from mindspore.nn import Cell
- from mindspore import ParameterTuple
-
- from mindarmour.diff_privacy.mechanisms import mechanisms
- from mindarmour.utils._check_param import check_param_type
- from mindarmour.utils._check_param import check_value_positive
- from mindarmour.utils._check_param import check_int_positive
-
-
# Clip type handed to _ClipGradients: 0 clips by value, 1 clips by norm.
GRADIENT_CLIP_TYPE = 1
# Element-wise reciprocal operator shared by the grad_scale implementations.
reciprocal = P.Reciprocal()
# Multitype function graph that divides a gradient by a scale value
# (used with hyper_map to undo loss scaling).
grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Return `grad` multiplied by the reciprocal of `scale`."""
    inv_scale = reciprocal(scale)
    return grad * inv_scale
-
-
class DPModel(Model):
    """
    This class overloads mindspore.train.model.Model to train with differential privacy.

    When `micro_batches` is set, each training batch is split into `micro_batches` slices. The
    gradients of every slice are clipped by norm (bound `norm_clip`) and summed before being handed
    to the optimizer. When `micro_batches` is falsy, the ordinary training network from
    `mindspore.train.amp.build_train_network` is built instead.

    Args:
        micro_batches (int): The number of small batches split from an original batch. A falsy value
            disables micro-batch splitting. Default: 2.
        norm_clip (float): The bound used to clip the gradients of each micro-batch. Must be a
            positive float. Default: 1.0.
        dp_mech (Mechanisms): The object that can generate different types of noise. Must be an
            instance of `Mechanisms`; the default None raises a TypeError. Default: None.
        kwargs: Forwarded to mindspore.train.model.Model (network, loss_fn, optimizer, metrics, ...).

    Raises:
        TypeError: If `dp_mech` is not an instance of `Mechanisms`.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> gaussian_mech = DPOptimizerClassFactory()
        >>> gaussian_mech.set_mechanisms('Gaussian',
        >>>                              norm_bound=args.l2_norm_bound,
        >>>                              initial_noise_multiplier=args.initial_noise_multiplier)
        >>> model = DPModel(micro_batches=2,
        >>>                 norm_clip=1.0,
        >>>                 dp_mech=gaussian_mech.mech,
        >>>                 network=net,
        >>>                 loss_fn=loss,
        >>>                 optimizer=optim,
        >>>                 metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """
    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
        # NOTE(review): the DP attributes are set before Model.__init__, presumably because it
        # builds the train network via the overridden _build_train_network -- keep this order.
        if micro_batches:
            self._micro_batches = check_int_positive('micro_batches', micro_batches)
        else:
            # A falsy value disables micro-batch splitting; _build_train_network then falls
            # back to amp.build_train_network.
            self._micro_batches = None
        float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
        self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
        if isinstance(dp_mech, mechanisms.Mechanisms):
            self._dp_mech = dp_mech
        else:
            raise TypeError('dp mechanisms should be instance of class Mechanisms, but got {}'.format(type(dp_mech)))
        super(DPModel, self).__init__(**kwargs)

    def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
        """
        Build the mixed precision DP training cell automatically.

        Args:
            network (Cell): Definition of the network.
            optimizer (Optimizer): Optimizer to update the Parameter.
            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
                Default: None.
            level (str): Supports [O0, O2]. Default: "O0".

                - O0: Do not change.
                - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
                  using dynamic loss scale.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
                If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
                scale the loss by LossScaleManager. If set, overwrite the level setting.

        Returns:
            Cell, a `_TrainOneStepWithLossScaleCell` when the loss scale manager provides an update
            cell, otherwise a `_TrainOneStepCell`; either one is switched to training mode.

        Raises:
            ValueError: If the loss scale manager provides an update cell while running on CPU
                without GE.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        # Explicit kwargs override the preset configuration of the chosen level.
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn, config.cast_model_type)

        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            loss_scale_manager = config.loss_scale_manager
            loss_scale = loss_scale_manager.get_loss_scale()
            update_cell = loss_scale_manager.get_update_cell()
            if update_cell is not None:
                # The dynamic loss-scale cell relies on control flow that is not supported on
                # CPU unless GE is enabled.
                if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                    raise ValueError("Only `loss_scale_manager=None` and "
                                     "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)` "
                                     "are supported in current version. If you use `O2` option, please "
                                     "use `loss_scale_manager=None` or `FixedLossScaleManager`")
                network = _TrainOneStepWithLossScaleCell(network,
                                                         optimizer,
                                                         scale_update_cell=update_cell,
                                                         micro_batches=self._micro_batches,
                                                         l2_norm_clip=self._norm_clip,
                                                         mech=self._dp_mech).set_train()
                return network
        network = _TrainOneStepCell(network,
                                    optimizer,
                                    loss_scale,
                                    micro_batches=self._micro_batches,
                                    l2_norm_clip=self._norm_clip,
                                    mech=self._dp_mech).set_train()
        return network

    def _build_train_network(self):
        """
        Build the train network.

        With an optimizer, the network is wrapped by `_amp_build_train_network` when micro-batches
        are enabled, otherwise by `amp.build_train_network`. The loss scale manager is only
        forwarded when one was explicitly set. Without an optimizer, the network is merely wrapped
        with its loss function (if any).

        Returns:
            Cell, the network used for training.
        """
        network = self._network
        if self._optimizer:
            # Keep the same keyword order the builders have always been called with.
            amp_kwargs = {'level': self._amp_level}
            if self._loss_scale_manager_set:
                amp_kwargs['loss_scale_manager'] = self._loss_scale_manager
            amp_kwargs['keep_batchnorm_fp32'] = self._keep_bn_fp32
            if self._micro_batches:
                build_fn = self._amp_build_train_network
            else:
                build_fn = amp.build_train_network
            network = build_fn(network, self._optimizer, self._loss_fn, **amp_kwargs)
        elif self._loss_fn:
            network = nn.WithLossCell(network, self._loss_fn)

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network
-
-
class _ClipGradients(nn.Cell):
    """
    Clip every gradient tensor in a tuple, either by value or by norm.

    Inputs:
        grads (tuple[Tensor]): Gradients.
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. Any other value returns
            `grads` unchanged.
        clip_value (float): Specifies how much to clip.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    def __init__(self):
        super(_ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.dtype = P.DType()

    def construct(self, grads, clip_type, clip_value):
        """
        Return a new tuple holding each gradient clipped according to `clip_type`.
        """
        if clip_type not in (0, 1):
            return grads

        # The bounds do not depend on the gradient, so build them once.
        upper = F.tuple_to_array((clip_value,))
        lower = F.tuple_to_array((-clip_value,))

        clipped = ()
        for grad in grads:
            if clip_type == 1:
                clipped_grad = self.clip_by_norm(grad, upper)
            else:
                clipped_grad = C.clip_by_value(grad, lower, upper)
            clipped = clipped + (clipped_grad,)
        return clipped
-
-
class _TrainOneStepWithLossScaleCell(Cell):
    r"""
    Network training with loss scaling and per-micro-batch gradient clipping.

    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
    Cell as args. The loss scale value can be updated in both host side or device side. The cell takes
    `data`, `label` and an optional `sens` as input. The `sens` is acting as loss scaling value. If you want
    to update it on host side, the value should be provided. If `sens` is not given, the loss scale update
    logic should be provided by `scale_update_cell`. If `scale_update_cell` is not None and `sens` is
    provided, the `scale_update_cell` will be ignored.

    Each batch is split into `micro_batches` slices along axis 0. For every slice the gradients are
    computed with the loss scale applied, divided by the loss scale, clipped by norm to `l2_norm_clip`
    and then summed over all slices. Dividing before clipping makes the clip bound apply to the true
    (unscaled) gradients instead of to gradients inflated by the loss scale.

    `mech` is stored but not used by this cell; noise, if any, has to be added elsewhere
    (NOTE(review): presumably by the optimizer -- confirm against the caller).

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
        micro_batches (int): The number of small batches split from an original batch. Default: None.
        l2_norm_clip (float): The bound each micro-batch gradient tensor is clipped to by norm.
            Must be a positive float. Default: 1.0.
        mech (Mechanisms): The object can generate the different type of noise. Default: None.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **sens** (Tensor) - Optional loss scale, Tensor of shape :math:`()`.

    Outputs:
        Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.

        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
        - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None):
        super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            # Non-GPU targets detect overflow through the NPU float-status buffer.
            self.gpu_target = False
            self.alloc_status = NPUAllocFloatStatus()
            self.get_status = NPUGetFloatStatus()
            self.clear_status = NPUClearFloatStatus()
        self.reduce_sum = ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = LessEqual()
        self.depend_parameter_use = ControlDepend(depend_mode=1)
        self.allreduce = P.AllReduce()
        self.parallel_mode = _get_parallel_mode()
        self.grad_reducer = F.identity
        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.add_flags(has_effect=True)

        # dp params
        self._micro_batches = micro_batches
        float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
        self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._mech = mech

    def _micro_batch_grad(self, record_data, record_label, weights, scaling_sens):
        """
        Return the unscaled, norm-clipped gradients of one micro-batch.

        The gradients are computed with `scaling_sens` as sensitivity, divided by `scaling_sens`
        and only then clipped, so `self._l2_norm` bounds the true gradients.
        """
        # The forward pass is only needed for the dtype/shape of the sensitivity tensor.
        loss = self.network(record_data, record_label)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        record_grad = self.grad(self.network, weights)(record_data, record_label, scaling_sens_filled)
        record_grad = self.hyper_map(F.partial(grad_scale, scaling_sens), record_grad)
        return self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)

    def construct(self, data, label, sens=None):
        """
        Run one DP training step with loss scaling.

        Returns:
            tuple, (loss of the whole batch, overflow flag, loss scale used for this step).
        """
        init = False
        if not self.gpu_target:
            # init overflow buffer
            init = self.alloc_status()
            # clear overflow buffer
            self.clear_status(init)

        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # DP clip: split the batch into micro-batches along axis 0.
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)

        # Sum the clipped micro-batch gradients on the host.
        # NOTE(review): asnumpy() is a host-side call, so this presumably requires PyNative mode.
        record_grad = self._micro_batch_grad(record_datas[0], record_labels[0], weights, scaling_sens)
        grad_len = len(record_grad)
        grad_sum = [grad.asnumpy() for grad in record_grad]
        for i in range(1, self._micro_batches):
            record_grad = self._micro_batch_grad(record_datas[i], record_labels[i], weights, scaling_sens)
            for j in range(grad_len):
                grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
        grads = tuple(Tensor(grad, ms.float32) for grad in grad_sum)
        # Loss reported for the whole batch.
        loss = self.network(data, label)

        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer
        if not self.gpu_target:
            self.get_status(init)
            # sum overflow buffer elements, 0:not overflow , >0:overflow
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            # convert flag_sum to scalar
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # if there is no overflow, do optimize
        if overflow:
            opt = False
        else:
            opt = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, opt)
-
-
class _TrainOneStepCell(Cell):
    r"""
    Training cell that wraps a network and an optimizer with per-micro-batch gradient clipping.

    Each batch is split into `micro_batches` slices along axis 0. The gradients of every slice are
    computed with the constant sensitivity `sens`, clipped by norm to `l2_norm_clip` and summed over
    all slices before the optimizer is applied. Different parallel modes are available to run the
    training.

    NOTE(review): the gradients are not divided by `sens` here, so clipping happens on gradients
    scaled by `sens`; any unscaling is presumably left to the optimizer -- confirm.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
        micro_batches (int): The number of small batches split from an original batch. Default: None.
        l2_norm_clip (float): The bound each micro-batch gradient tensor is clipped to by norm.
            Must be a positive float. Default: 1.0.
        mech (Mechanisms): The object can generate the different type of noise. Stored but not used
            by this cell. Default: None.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
        super(_TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens
        # Gradients are only all-reduced in data/hybrid parallel mode.
        self.reducer_flag = _get_parallel_mode() in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        self.grad_reducer = None
        if self.reducer_flag:
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, _get_mirror_mean(), _get_device_num())

        # dp params
        self._micro_batches = micro_batches
        self._l2_norm = check_value_positive('l2_norm_clip',
                                             check_param_type('l2_norm_clip', l2_norm_clip, float))
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._mech = mech

    def construct(self, data, label):
        """
        Run one training step and return the loss of the whole batch.
        """
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)

        # Host-side (NumPy) running sum of the clipped micro-batch gradients.
        accumulated = None
        for idx in range(self._micro_batches):
            micro_data = record_datas[idx]
            micro_label = record_labels[idx]
            micro_loss = self.network(micro_data, micro_label)
            sens = P.Fill()(P.DType()(micro_loss), P.Shape()(micro_loss), self.sens)
            raw_grads = self.grad(self.network, weights)(micro_data, micro_label, sens)
            clipped = self._clip_by_global_norm(raw_grads, GRADIENT_CLIP_TYPE, self._l2_norm)
            if accumulated is None:
                accumulated = [piece.asnumpy() for piece in clipped]
            else:
                accumulated = [total + piece.asnumpy() for total, piece in zip(accumulated, clipped)]

        grads = tuple(Tensor(total, ms.float32) for total in accumulated)
        loss = self.network(data, label)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
|